mirror of https://github.com/coqui-ai/TTS.git
Allow saving / loading checkpoints from cloud paths (#683)
* Allow saving / loading checkpoints from cloud paths

  Allows saving and loading checkpoints directly from cloud paths like
  Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec.

  Note: The user will have to install the relevant dependency for each
  protocol. Otherwise fsspec will fail and specify which dependency is
  missing.

* Append suffix _fsspec to save/load function names
* Add a lower bound to the fsspec dependency

  Skips the 0 major version.

* Add missing changes from refactor
* Use fsspec for remaining artifacts
* Add test case with path requiring fsspec
* Avoid writing logs to file unless output_path is local
* Document the possibility of using paths supported by fsspec
* Fix style and lint
* Add missing lint fixes
* Add type annotations to new functions
* Use Coqpit method for converting config to dict
* Fix type annotation in semi-new function
* Add return type for load_fsspec
* Fix bug where fs not always created
* Restore the experiment removal functionality
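For orientation, a minimal usage sketch of the two helpers this commit introduces, save_fsspec and load_fsspec; the bucket name and checkpoint contents are hypothetical, and the protocol-specific dependency (e.g. s3fs for s3:// paths) is assumed to be installed:

import torch

from TTS.utils.io import load_fsspec, save_fsspec

# save_fsspec wraps torch.save in fsspec.open, so any fsspec-supported URL
# (s3://, gs://, file://, or a plain local path) can hold a checkpoint.
state = {"model": {"weight": torch.zeros(4)}, "step": 1000}
save_fsspec(state, "s3://my-bucket/run-1/checkpoint_1000.pth.tar")

# load_fsspec forwards its keyword arguments to torch.load.
state = load_fsspec("s3://my-bucket/run-1/checkpoint_1000.pth.tar", map_location="cpu")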
This commit is contained in:
parent 181177a990
commit ced4cfdbbf
@@ -6,7 +6,7 @@ import numpy as np
 import tensorflow as tf
 import torch
 
-from TTS.utils.io import load_config
+from TTS.utils.io import load_config, load_fsspec
 from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
     compare_torch_tf,
     convert_tf_name,
@@ -33,7 +33,7 @@ num_speakers = 0
 
 # init torch model
 model = setup_generator(c)
-checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
 state_dict = checkpoint["model"]
 model.load_state_dict(state_dict)
 model.remove_weight_norm()

@@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
 from TTS.tts.tf.utils.generic_utils import save_checkpoint
 from TTS.tts.utils.text.symbols import phonemes, symbols
-from TTS.utils.io import load_config
+from TTS.utils.io import load_config, load_fsspec
 
 sys.path.append("/home/erogol/Projects")
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -32,7 +32,7 @@ num_speakers = 0
 
 # init torch model
 model = setup_model(c)
-checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
+checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
 state_dict = checkpoint["model"]
 model.load_state_dict(state_dict)
 

@@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters
+from TTS.utils.io import load_fsspec
 
 use_cuda = torch.cuda.is_available()
 
@@ -239,7 +240,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     model = setup_model(c)
 
     # restore model
-    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
 
     if use_cuda:

@@ -17,6 +17,7 @@ from TTS.trainer import init_training
 from TTS.tts.datasets import load_meta_data
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
+from TTS.utils.io import load_fsspec
 from TTS.utils.radam import RAdam
 from TTS.utils.training import NoamLR, check_update
 
@@ -169,7 +170,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         raise Exception("The %s not is a loss supported" % c.loss)
 
     if args.restore_path:
-        checkpoint = torch.load(args.restore_path)
+        checkpoint = load_fsspec(args.restore_path)
         try:
             model.load_state_dict(checkpoint["model"])
 

@@ -3,6 +3,7 @@ import os
 import re
 from typing import Dict
 
+import fsspec
 import yaml
 from coqpit import Coqpit
 
@@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
 def read_json_with_comments(json_path):
     """for backward compat."""
     # fallback to json
-    with open(json_path, "r", encoding="utf-8") as f:
+    with fsspec.open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments
     input_str = re.sub(r"\\\n", "", input_str)
@@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
     config_dict = {}
     ext = os.path.splitext(config_path)[1]
     if ext in (".yml", ".yaml"):
-        with open(config_path, "r", encoding="utf-8") as f:
+        with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
     elif ext == ".json":
         try:
-            with open(config_path, "r", encoding="utf-8") as f:
-                input_str = f.read()
-            data = json.loads(input_str)
+            with fsspec.open(config_path, "r", encoding="utf-8") as f:
+                data = json.load(f)
         except json.decoder.JSONDecodeError:
             # backwards compat.
             data = read_json_with_comments(config_path)
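Since load_config now reads through fsspec.open, a config can come from a remote URL as well as a local file. A small sketch, assuming load_config lives in TTS.config as the hunk context suggests; the bucket path is hypothetical:

from TTS.config import load_config

# Any fsspec-supported URL works once its protocol dependency is installed.
config = load_config("gs://my-bucket/run-1/config.json")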
@@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit):
         num_eval_loader_workers (int):
             Number of workers for evaluation time dataloader.
         output_path (str):
-            Path for training output folder. The nonexist part of the given path is created automatically.
-            All training outputs are saved there.
+            Path for training output folder, either a local file path or other
+            URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
+            S3 (s3://) paths. The nonexist part of the given path is created
+            automatically. All training artefacts are saved there.
     """
 
     model: str = None

@@ -2,6 +2,8 @@ import numpy as np
 import torch
 from torch import nn
 
+from TTS.utils.io import load_fsspec
+
 
 class LSTMWithProjection(nn.Module):
     def __init__(self, input_size, hidden_size, proj_size):
@@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
 
     # pylint: disable=unused-argument, redefined-builtin
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if use_cuda:
             self.cuda()

@@ -2,6 +2,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+from TTS.utils.io import load_fsspec
+
 
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
@@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
         return embeddings
 
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if use_cuda:
             self.cuda()

@@ -6,11 +6,11 @@ import re
 from multiprocessing import Manager
 
 import numpy as np
-import torch
 from scipy import signal
 
 from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
+from TTS.utils.io import save_fsspec
 
 
 class Storage(object):
@@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
         "loss": model_loss,
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
-    torch.save(state, checkpoint_path)
+    save_fsspec(state, checkpoint_path)
 
 
 def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
         bestmodel_path = "best_model.pth.tar"
         bestmodel_path = os.path.join(out_path, bestmodel_path)
         print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-        torch.save(state, bestmodel_path)
+        save_fsspec(state, bestmodel_path)
     return best_loss

@@ -1,7 +1,7 @@
 import datetime
 import os
 
-import torch
+from TTS.utils.io import save_fsspec
 
 
 def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
         "loss": model_loss,
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
-    torch.save(state, checkpoint_path)
+    save_fsspec(state, checkpoint_path)
 
 
 def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
         bestmodel_path = "best_model.pth.tar"
         bestmodel_path = os.path.join(out_path, bestmodel_path)
         print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
-        torch.save(state, bestmodel_path)
+        save_fsspec(state, bestmodel_path)
     return best_loss

@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
 
-import glob
 import importlib
 import logging
 import os
@@ -12,7 +11,9 @@ import traceback
 from argparse import Namespace
 from dataclasses import dataclass, field
 from typing import Dict, List, Tuple, Union
+from urllib.parse import urlparse
 
+import fsspec
 import torch
 from coqpit import Coqpit
 from torch import nn
@@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (
     KeepAverage,
     count_parameters,
-    create_experiment_folder,
+    get_experiment_folder_path,
     get_git_branch,
     remove_experiment_folder,
     set_init_dict,
     to_cuda,
 )
-from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
+from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
 from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -173,7 +174,6 @@ class Trainer:
         self.best_loss = float("inf")
         self.train_loader = None
         self.eval_loader = None
-        self.output_audio_path = os.path.join(output_path, "test_audios")
 
         self.keep_avg_train = None
         self.keep_avg_eval = None
@@ -309,7 +309,7 @@ class Trainer:
             return obj
 
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
-        checkpoint = torch.load(restore_path)
+        checkpoint = load_fsspec(restore_path)
        try:
             print(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
@@ -776,7 +776,7 @@ class Trainer:
         """🏃 train -> evaluate -> test for the number of epochs."""
         if self.restore_step != 0 or self.args.best_path:
             print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
+            self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")
 
         self.total_steps_done = self.restore_step
@@ -834,9 +834,16 @@ class Trainer:
 
     @staticmethod
     def _setup_logger_config(log_file: str) -> None:
-        logging.basicConfig(
-            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
-        )
+        handlers = [logging.StreamHandler()]
+
+        # Only add a log file if the output location is local due to poor
+        # support for writing logs to file-like objects.
+        parsed_url = urlparse(log_file)
+        if not parsed_url.scheme or parsed_url.scheme == "file":
+            schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
+            handlers.append(logging.FileHandler(schemeless_path))
+
+        logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
 
     @staticmethod
     def _is_apex_available() -> bool:
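The logger change keys entirely off the URL scheme: only schemeless or file:// paths get a FileHandler. A standalone sketch of that check, with hypothetical paths:

from urllib.parse import urlparse

for log_file in ["/tmp/run-1/trainer_log.txt", "s3://my-bucket/run-1/trainer_log.txt"]:
    parsed_url = urlparse(log_file)
    # A plain local path has an empty scheme; file:// is explicitly local.
    is_local = not parsed_url.scheme or parsed_url.scheme == "file"
    print(log_file, "->", "file + console" if is_local else "console only")
# /tmp/run-1/trainer_log.txt -> file + console
# s3://my-bucket/run-1/trainer_log.txt -> console only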
@@ -926,22 +933,27 @@ def init_arguments():
     return parser
 
 
-def get_last_checkpoint(path):
+def get_last_checkpoint(path: str) -> Tuple[str, str]:
     """Get latest checkpoint or/and best model in path.
 
     It is based on globbing for `*.pth.tar` and the RegEx
     `(checkpoint|best_model)_([0-9]+)`.
 
     Args:
-        path (list): Path to files to be compared.
+        path: Path to files to be compared.
 
     Raises:
         ValueError: If no checkpoint or best_model files are found.
 
     Returns:
-        last_checkpoint (str): Last checkpoint filename.
+        Path to the last checkpoint
+        Path to best checkpoint
     """
-    file_names = glob.glob(os.path.join(path, "*.pth.tar"))
+    fs = fsspec.get_mapper(path).fs
+    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
+    scheme = urlparse(path).scheme
+    if scheme:  # scheme is not preserved in fs.glob, add it back
+        file_names = [scheme + "://" + file_name for file_name in file_names]
     last_models = {}
     last_model_nums = {}
     for key in ["checkpoint", "best_model"]:
@@ -963,7 +975,7 @@ def get_last_checkpoint(path):
         key_file_names = [fn for fn in file_names if key in fn]
         if last_model is None and len(key_file_names) > 0:
             last_model = max(key_file_names, key=os.path.getctime)
-            last_model_num = torch.load(last_model)["step"]
+            last_model_num = load_fsspec(last_model)["step"]
 
         if last_model is not None:
             last_models[key] = last_model
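The scheme re-attachment above is needed because fsspec filesystems return protocol-stripped names from glob. A sketch of that behavior; the bucket is hypothetical and running it needs the s3fs extra:

import fsspec

path = "s3://my-bucket/run-1"
fs = fsspec.get_mapper(path).fs
# fs.glob yields names like "my-bucket/run-1/checkpoint_1000.pth.tar",
# without the "s3://" prefix, so the scheme is prepended before the paths
# are handed to load_fsspec.
file_names = ["s3://" + name for name in fs.glob(path + "/*.pth.tar")]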
@@ -1030,12 +1042,11 @@ def process_args(args, config=None):
         print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
-        experiment_path = create_experiment_folder(config.output_path, config.run_name)
+        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
     if args.rank == 0:
-        os.makedirs(audio_path, exist_ok=True)
         new_fields = {}
         if args.restore_path:
             new_fields["restore_path"] = args.restore_path
@@ -1047,8 +1058,6 @@ def process_args(args, config=None):
             used_characters = parse_symbols()
             new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)
-        os.chmod(audio_path, 0o775)
-        os.chmod(experiment_path, 0o775)
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
         tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)

@@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 
 
 @dataclass
@@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.text import make_symbols
 from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
 
 
@@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if "r" in state:
             self.decoder.set_r(state["r"])

@@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 
 
 class GlowTTS(BaseTTS):
@@ -382,7 +383,7 @@ class GlowTTS(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 
 
 @dataclass
@@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -2,6 +2,7 @@ import datetime
 import importlib
 import pickle
 
+import fsspec
 import numpy as np
 import tensorflow as tf
 
@@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)
 
 
 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:

@@ -1,6 +1,7 @@
 import datetime
 import pickle
 
+import fsspec
 import tensorflow as tf
 
 
@@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
         "r": r,
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)
 
 
 def load_checkpoint(model, checkpoint_path):
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:

@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf
 
 
@@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # same model binary if outputpath is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
             f.write(tflite_model)
         return None
     return tflite_model

@@ -3,6 +3,7 @@ import os
 import random
 from typing import Any, Dict, List, Tuple, Union
 
+import fsspec
 import numpy as np
 import torch
 from coqpit import Coqpit
@@ -84,12 +85,12 @@ class SpeakerManager:
 
     @staticmethod
     def _load_json(json_file_path: str) -> Dict:
-        with open(json_file_path) as f:
+        with fsspec.open(json_file_path, "r") as f:
             return json.load(f)
 
     @staticmethod
     def _save_json(json_file_path: str, data: dict) -> None:
-        with open(json_file_path, "w") as f:
+        with fsspec.open(json_file_path, "w") as f:
             json.dump(data, f, indent=4)
 
     @property
@@ -294,9 +295,10 @@ def _set_file_path(path):
     Intended to band aid the different paths returned in restored and continued training."""
     path_restore = os.path.join(os.path.dirname(path), "speakers.json")
     path_continue = os.path.join(path, "speakers.json")
-    if os.path.exists(path_restore):
+    fs = fsspec.get_mapper(path).fs
+    if fs.exists(path_restore):
         return path_restore
-    if os.path.exists(path_continue):
+    if fs.exists(path_continue):
         return path_continue
     raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
 
@@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
         json_file = out_path
     else:
         json_file = _set_file_path(out_path)
-    with open(json_file) as f:
+    with fsspec.open(json_file, "r") as f:
         return json.load(f)
 
 
@@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
     if out_path is not None:
         speakers_json_path = _set_file_path(out_path)
-        with open(speakers_json_path, "w") as f:
+        with fsspec.open(speakers_json_path, "w") as f:
             json.dump(speaker_mapping, f, indent=4)
 
 

@@ -1,15 +1,14 @@
 # -*- coding: utf-8 -*-
 import datetime
-import glob
 import importlib
 import os
 import re
-import shutil
 import subprocess
 import sys
 from pathlib import Path
 from typing import Dict
 
+import fsspec
 import torch
 
 
@@ -58,23 +57,22 @@ def get_commit_hash():
     return commit
 
 
-def create_experiment_folder(root_path, model_name):
-    """Create a folder with the current date and time"""
+def get_experiment_folder_path(root_path, model_name):
+    """Get an experiment folder path with the current date and time"""
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
     commit_hash = get_commit_hash()
     output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
-    os.makedirs(output_folder, exist_ok=True)
     print(" > Experiment folder: {}".format(output_folder))
     return output_folder
 
 
 def remove_experiment_folder(experiment_path):
     """Check folder if there is a checkpoint, otherwise remove the folder"""
-    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
+    fs = fsspec.get_mapper(experiment_path).fs
+    checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
     if not checkpoint_files:
-        if os.path.exists(experiment_path):
-            shutil.rmtree(experiment_path, ignore_errors=True)
+        if fs.exists(experiment_path):
+            fs.rm(experiment_path, recursive=True)
             print(" ! Run is removed from {}".format(experiment_path))
     else:
         print(" ! Run is kept in {}".format(experiment_path))

@@ -1,9 +1,11 @@
 import datetime
-import glob
+import json
 import os
 import pickle as pickle_tts
-from shutil import copyfile
+import shutil
+from typing import Any
 
+import fsspec
 import torch
 from coqpit import Coqpit
 
@@ -24,7 +26,7 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
-def copy_model_files(config, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields):
     """Copy config.json and other model files to training folder and add
     new fields.
 
@@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
     copy_config_path = os.path.join(out_path, "config.json")
     # add extra information fields
     config.update(new_fields, allow_new=True)
-    config.save_json(copy_config_path)
+    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
+    with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
+        json.dump(config.to_dict(), f, indent=4)
 
     # copy model stats file if available
     if config.audio.stats_path is not None:
         copy_stats_path = os.path.join(out_path, "scale_stats.npy")
-        if not os.path.exists(copy_stats_path):
-            copyfile(
-                config.audio.stats_path,
-                copy_stats_path,
-            )
+        filesystem = fsspec.get_mapper(copy_stats_path).fs
+        if not filesystem.exists(copy_stats_path):
+            with fsspec.open(config.audio.stats_path, "rb") as source_file:
+                with fsspec.open(copy_stats_path, "wb") as target_file:
+                    shutil.copyfileobj(source_file, target_file)
+
+
+def load_fsspec(path: str, **kwargs) -> Any:
+    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
+
+    Args:
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.load.
+
+    Returns:
+        Object stored in path.
+    """
+    with fsspec.open(path, "rb") as f:
+        return torch.load(f, **kwargs)
 
 
 def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
     try:
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
     except ModuleNotFoundError:
         pickle_tts.Unpickler = RenamingUnpickler
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
     model.load_state_dict(state["model"])
     if use_cuda:
         model.cuda()
@@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pyli
     return model, state
 
 
+def save_fsspec(state: Any, path: str, **kwargs):
+    """Like torch.save but can save to other locations (e.g. s3:// , gs://).
+
+    Args:
+        state: State object to save
+        path: Any path or url supported by fsspec.
+        **kwargs: Keyword arguments forwarded to torch.save.
+    """
+    with fsspec.open(path, "wb") as f:
+        torch.save(state, f, **kwargs)
+
+
 def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
     if hasattr(model, "module"):
         model_state = model.module.state_dict()
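These two helpers work because torch.save and torch.load accept file-like objects, which is exactly what fsspec.open yields. A self-contained round trip over the file:// protocol, which needs no extra dependency (the path is hypothetical):

import fsspec
import torch

with fsspec.open("file:///tmp/demo.pth.tar", "wb") as f:
    torch.save({"step": 1}, f)

with fsspec.open("file:///tmp/demo.pth.tar", "rb") as f:
    print(torch.load(f, map_location="cpu")["step"])  # 1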
@@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
     state.update(kwargs)
-    torch.save(state, output_path)
+    save_fsspec(state, output_path)
 
 
 def save_checkpoint(
@@ -147,18 +178,16 @@ def save_best_model(
             model_loss=current_loss,
             **kwargs,
         )
+        fs = fsspec.get_mapper(out_path).fs
         # only delete previous if current is saved successfully
        if not keep_all_best or (current_step < keep_after):
-            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
+            model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
             for model_name in model_names:
-                if os.path.basename(model_name) == best_model_name:
-                    continue
-                os.remove(model_name)
-        # create symlink to best model for convinience
-        link_name = "best_model.pth.tar"
-        link_path = os.path.join(out_path, link_name)
-        if os.path.islink(link_path) or os.path.isfile(link_path):
-            os.remove(link_path)
-        os.symlink(best_model_name, os.path.join(out_path, link_name))
+                if os.path.basename(model_name) != best_model_name:
+                    fs.rm(model_name)
+        # create a shortcut which always points to the currently best model
+        shortcut_name = "best_model.pth.tar"
+        shortcut_path = os.path.join(out_path, shortcut_name)
+        fs.copy(checkpoint_path, shortcut_path)
         best_loss = current_loss
     return best_loss
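Note the behavioral change in save_best_model: the old code maintained best_model.pth.tar as a symlink, which object stores such as S3 and GCS cannot represent, so the new code materializes a real copy of the best checkpoint via fs.copy instead.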
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@@ -222,7 +223,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
         """
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)

@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn.utils import remove_weight_norm, weight_norm
 
+from TTS.utils.io import load_fsspec
+
 LRELU_SLOPE = 0.1
 
 
@@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torch.nn.utils import weight_norm
 
+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.melgan import ResidualStack
 
 
@@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -3,6 +3,7 @@ import math
 import numpy as np
 import torch
 
+from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample
 
@@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
 
 from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
 from TTS.vocoder.models.base_vocoder import BaseVocoder
@@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()

@@ -1,6 +1,7 @@
 import datetime
 import pickle
 
+import fsspec
 import tensorflow as tf
 
 
@@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
         "date": datetime.date.today().strftime("%B %d, %Y"),
     }
     state.update(kwargs)
-    pickle.dump(state, open(output_path, "wb"))
+    with fsspec.open(output_path, "wb") as f:
+        pickle.dump(state, f)
 
 
 def load_checkpoint(model, checkpoint_path):
     """Load TF Vocoder model"""
-    checkpoint = pickle.load(open(checkpoint_path, "rb"))
+    with fsspec.open(checkpoint_path, "rb") as f:
+        checkpoint = pickle.load(f)
     chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
     tf_vars = model.weights
     for tf_var in tf_vars:

@@ -1,3 +1,4 @@
+import fsspec
 import tensorflow as tf
 
 
@@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
     print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
     if output_path is not None:
         # same model binary if outputpath is provided
-        with open(output_path, "wb") as f:
+        with fsspec.open(output_path, "wb") as f:
             f.write(tflite_model)
         return None
     return tflite_model

@@ -24,3 +24,4 @@ mecab-python3==1.0.3
 unidic-lite==1.0.8
 # gruut+supported langs
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
+fsspec>=2021.04.0
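The requirement pins only fsspec itself. As the commit message notes, each remote protocol needs its own extra package, which fsspec names in its error message when missing; in practice that means, for example, "pip install s3fs" for s3:// paths and "pip install gcsfs" for gs:// paths (package names per fsspec's standard protocol implementations, not pinned by this commit).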
@@ -0,0 +1,55 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import Tacotron2Config
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+config = Tacotron2Config(
+    r=5,
+    batch_size=8,
+    eval_batch_size=8,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=False,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+    print_eval=True,
+    max_decoder_steps=50,
+)
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
+    f"--coqpit.output_path file://{output_path} "
+    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.test_delay_epochs 0 "
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
+)
+run_cli(command_train)
+shutil.rmtree(continue_path)
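The new test exercises the fsspec code path without cloud credentials by passing file:// URLs for --config_path, --coqpit.output_path and --continue_path, forcing every save/restore step through fsspec while staying on the local filesystem.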