mirror of https://github.com/coqui-ai/TTS.git
Allow saving / loading checkpoints from cloud paths (#683)
* Allow saving / loading checkpoints from cloud paths Allows saving and loading checkpoints directly from cloud paths like Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec. Note: The user will have to install the relevant dependency for each protocol. Otherwise fsspec will fail and specify which dependency is missing. * Append suffix _fsspec to save/load function names * Add a lower bound to the fsspec dependency Skips the 0 major version. * Add missing changes from refactor * Use fsspec for remaining artifacts * Add test case with path requiring fsspec * Avoid writing logs to file unless output_path is local * Document the possibility of using paths supported by fsspec * Fix style and lint * Add missing lint fixes * Add type annotations to new functions * Use Coqpit method for converting config to dict * Fix type annotation in semi-new function * Add return type for load_fsspec * Fix bug where fs not always created * Restore the experiment removal functionality
This commit is contained in:
parent
181177a990
commit
ced4cfdbbf
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.io import load_config, load_fsspec
|
||||
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
|
||||
compare_torch_tf,
|
||||
convert_tf_name,
|
||||
|
@ -33,7 +33,7 @@ num_speakers = 0
|
|||
|
||||
# init torch model
|
||||
model = setup_generator(c)
|
||||
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
|
||||
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
|
||||
state_dict = checkpoint["model"]
|
||||
model.load_state_dict(state_dict)
|
||||
model.remove_weight_norm()
|
||||
|
|
|
@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
|
|||
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
|
||||
from TTS.tts.tf.utils.generic_utils import save_checkpoint
|
||||
from TTS.tts.utils.text.symbols import phonemes, symbols
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.io import load_config, load_fsspec
|
||||
|
||||
sys.path.append("/home/erogol/Projects")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
|
@ -32,7 +32,7 @@ num_speakers = 0
|
|||
|
||||
# init torch model
|
||||
model = setup_model(c)
|
||||
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
|
||||
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
|
||||
state_dict = checkpoint["model"]
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
|
|||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
|
@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model = setup_model(c)
|
||||
|
||||
# restore model
|
||||
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
||||
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
if use_cuda:
|
||||
|
|
|
@ -17,6 +17,7 @@ from TTS.trainer import init_training
|
|||
from TTS.tts.datasets import load_meta_data
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.training import NoamLR, check_update
|
||||
|
||||
|
@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
checkpoint = load_fsspec(args.restore_path)
|
||||
try:
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import re
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import yaml
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
|
|||
def read_json_with_comments(json_path):
|
||||
"""for backward compat."""
|
||||
# fallback to json
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
with fsspec.open(json_path, "r", encoding="utf-8") as f:
|
||||
input_str = f.read()
|
||||
# handle comments
|
||||
input_str = re.sub(r"\\\n", "", input_str)
|
||||
|
@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
|
|||
config_dict = {}
|
||||
ext = os.path.splitext(config_path)[1]
|
||||
if ext in (".yml", ".yaml"):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with fsspec.open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
elif ext == ".json":
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
input_str = f.read()
|
||||
data = json.loads(input_str)
|
||||
with fsspec.open(config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except json.decoder.JSONDecodeError:
|
||||
# backwards compat.
|
||||
data = read_json_with_comments(config_path)
|
||||
|
|
|
@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit):
|
|||
num_eval_loader_workers (int):
|
||||
Number of workers for evaluation time dataloader.
|
||||
output_path (str):
|
||||
Path for training output folder. The nonexist part of the given path is created automatically.
|
||||
All training outputs are saved there.
|
||||
Path for training output folder, either a local file path or other
|
||||
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
|
||||
S3 (s3://) paths. The nonexist part of the given path is created
|
||||
automatically. All training artefacts are saved there.
|
||||
"""
|
||||
|
||||
model: str = None
|
||||
|
|
|
@ -2,6 +2,8 @@ import numpy as np
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class LSTMWithProjection(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, proj_size):
|
||||
|
@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
|
|||
|
||||
# pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
|
|
|
@ -2,6 +2,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
|
@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
return embeddings
|
||||
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
|
|
|
@ -6,11 +6,11 @@ import re
|
|||
from multiprocessing import Manager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy import signal
|
||||
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
class Storage(object):
|
||||
|
@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
|
|||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
torch.save(state, checkpoint_path)
|
||||
save_fsspec(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
|
||||
|
@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
|
|||
bestmodel_path = "best_model.pth.tar"
|
||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
|
||||
torch.save(state, bestmodel_path)
|
||||
save_fsspec(state, bestmodel_path)
|
||||
return best_loss
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import datetime
|
||||
import os
|
||||
|
||||
import torch
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
||||
|
@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
|||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
torch.save(state, checkpoint_path)
|
||||
save_fsspec(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
|
||||
|
@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
|
|||
bestmodel_path = "best_model.pth.tar"
|
||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
|
||||
torch.save(state, bestmodel_path)
|
||||
save_fsspec(state, bestmodel_path)
|
||||
return best_loss
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
|
@ -12,7 +11,9 @@ import traceback
|
|||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed
|
|||
from TTS.utils.generic_utils import (
|
||||
KeepAverage,
|
||||
count_parameters,
|
||||
create_experiment_folder,
|
||||
get_experiment_folder_path,
|
||||
get_git_branch,
|
||||
remove_experiment_folder,
|
||||
set_init_dict,
|
||||
to_cuda,
|
||||
)
|
||||
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
|
||||
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
|
||||
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
|
@ -173,7 +174,6 @@ class Trainer:
|
|||
self.best_loss = float("inf")
|
||||
self.train_loader = None
|
||||
self.eval_loader = None
|
||||
self.output_audio_path = os.path.join(output_path, "test_audios")
|
||||
|
||||
self.keep_avg_train = None
|
||||
self.keep_avg_eval = None
|
||||
|
@ -309,7 +309,7 @@ class Trainer:
|
|||
return obj
|
||||
|
||||
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||
checkpoint = torch.load(restore_path)
|
||||
checkpoint = load_fsspec(restore_path)
|
||||
try:
|
||||
print(" > Restoring Model...")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
@ -776,7 +776,7 @@ class Trainer:
|
|||
"""🏃 train -> evaluate -> test for the number of epochs."""
|
||||
if self.restore_step != 0 or self.args.best_path:
|
||||
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
|
||||
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
|
||||
self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with loaded last best loss {self.best_loss}.")
|
||||
|
||||
self.total_steps_done = self.restore_step
|
||||
|
@ -834,9 +834,16 @@ class Trainer:
|
|||
|
||||
@staticmethod
|
||||
def _setup_logger_config(log_file: str) -> None:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
|
||||
)
|
||||
handlers = [logging.StreamHandler()]
|
||||
|
||||
# Only add a log file if the output location is local due to poor
|
||||
# support for writing logs to file-like objects.
|
||||
parsed_url = urlparse(log_file)
|
||||
if not parsed_url.scheme or parsed_url.scheme == "file":
|
||||
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
|
||||
handlers.append(logging.FileHandler(schemeless_path))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
|
||||
|
||||
@staticmethod
|
||||
def _is_apex_available() -> bool:
|
||||
|
@ -926,22 +933,27 @@ def init_arguments():
|
|||
return parser
|
||||
|
||||
|
||||
def get_last_checkpoint(path):
|
||||
def get_last_checkpoint(path: str) -> Tuple[str, str]:
|
||||
"""Get latest checkpoint or/and best model in path.
|
||||
|
||||
It is based on globbing for `*.pth.tar` and the RegEx
|
||||
`(checkpoint|best_model)_([0-9]+)`.
|
||||
|
||||
Args:
|
||||
path (list): Path to files to be compared.
|
||||
path: Path to files to be compared.
|
||||
|
||||
Raises:
|
||||
ValueError: If no checkpoint or best_model files are found.
|
||||
|
||||
Returns:
|
||||
last_checkpoint (str): Last checkpoint filename.
|
||||
Path to the last checkpoint
|
||||
Path to best checkpoint
|
||||
"""
|
||||
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
|
||||
fs = fsspec.get_mapper(path).fs
|
||||
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
|
||||
scheme = urlparse(path).scheme
|
||||
if scheme: # scheme is not preserved in fs.glob, add it back
|
||||
file_names = [scheme + "://" + file_name for file_name in file_names]
|
||||
last_models = {}
|
||||
last_model_nums = {}
|
||||
for key in ["checkpoint", "best_model"]:
|
||||
|
@ -963,7 +975,7 @@ def get_last_checkpoint(path):
|
|||
key_file_names = [fn for fn in file_names if key in fn]
|
||||
if last_model is None and len(key_file_names) > 0:
|
||||
last_model = max(key_file_names, key=os.path.getctime)
|
||||
last_model_num = torch.load(last_model)["step"]
|
||||
last_model_num = load_fsspec(last_model)["step"]
|
||||
|
||||
if last_model is not None:
|
||||
last_models[key] = last_model
|
||||
|
@ -1030,12 +1042,11 @@ def process_args(args, config=None):
|
|||
print(" > Mixed precision mode is ON")
|
||||
experiment_path = args.continue_path
|
||||
if not experiment_path:
|
||||
experiment_path = create_experiment_folder(config.output_path, config.run_name)
|
||||
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
|
||||
audio_path = os.path.join(experiment_path, "test_audios")
|
||||
# setup rank 0 process in distributed training
|
||||
tb_logger = None
|
||||
if args.rank == 0:
|
||||
os.makedirs(audio_path, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
|
@ -1047,8 +1058,6 @@ def process_args(args, config=None):
|
|||
used_characters = parse_symbols()
|
||||
new_fields["characters"] = used_characters
|
||||
copy_model_files(config, experiment_path, new_fields)
|
||||
os.chmod(audio_path, 0o775)
|
||||
os.chmod(experiment_path, 0o775)
|
||||
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||
|
|
|
@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
|
@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if "r" in state:
|
||||
self.decoder.set_r(state["r"])
|
||||
|
|
|
@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
|
|||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class GlowTTS(BaseTTS):
|
||||
|
@ -382,7 +383,7 @@ class GlowTTS(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
|
|||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -2,6 +2,7 @@ import datetime
|
|||
import importlib
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
|
|||
"r": r,
|
||||
}
|
||||
state.update(kwargs)
|
||||
pickle.dump(state, open(output_path, "wb"))
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
checkpoint = pickle.load(open(checkpoint_path, "rb"))
|
||||
with fsspec.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = pickle.load(f)
|
||||
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
|
@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
|
|||
"r": r,
|
||||
}
|
||||
state.update(kwargs)
|
||||
pickle.dump(state, open(output_path, "wb"))
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
checkpoint = pickle.load(open(checkpoint_path, "rb"))
|
||||
with fsspec.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = pickle.load(f)
|
||||
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
|
@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
|
|||
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
|
||||
if output_path is not None:
|
||||
# same model binary if outputpath is provided
|
||||
with open(output_path, "wb") as f:
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
return None
|
||||
return tflite_model
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import random
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -84,12 +85,12 @@ class SpeakerManager:
|
|||
|
||||
@staticmethod
|
||||
def _load_json(json_file_path: str) -> Dict:
|
||||
with open(json_file_path) as f:
|
||||
with fsspec.open(json_file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def _save_json(json_file_path: str, data: dict) -> None:
|
||||
with open(json_file_path, "w") as f:
|
||||
with fsspec.open(json_file_path, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
@property
|
||||
|
@ -294,9 +295,10 @@ def _set_file_path(path):
|
|||
Intended to band aid the different paths returned in restored and continued training."""
|
||||
path_restore = os.path.join(os.path.dirname(path), "speakers.json")
|
||||
path_continue = os.path.join(path, "speakers.json")
|
||||
if os.path.exists(path_restore):
|
||||
fs = fsspec.get_mapper(path).fs
|
||||
if fs.exists(path_restore):
|
||||
return path_restore
|
||||
if os.path.exists(path_continue):
|
||||
if fs.exists(path_continue):
|
||||
return path_continue
|
||||
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
|
||||
|
||||
|
@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
|
|||
json_file = out_path
|
||||
else:
|
||||
json_file = _set_file_path(out_path)
|
||||
with open(json_file) as f:
|
||||
with fsspec.open(json_file, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
|
@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
|
|||
"""Saves speaker mapping if not yet present."""
|
||||
if out_path is not None:
|
||||
speakers_json_path = _set_file_path(out_path)
|
||||
with open(speakers_json_path, "w") as f:
|
||||
with fsspec.open(speakers_json_path, "w") as f:
|
||||
json.dump(speaker_mapping, f, indent=4)
|
||||
|
||||
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -58,23 +57,22 @@ def get_commit_hash():
|
|||
return commit
|
||||
|
||||
|
||||
def create_experiment_folder(root_path, model_name):
|
||||
"""Create a folder with the current date and time"""
|
||||
def get_experiment_folder_path(root_path, model_name):
|
||||
"""Get an experiment folder path with the current date and time"""
|
||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||
commit_hash = get_commit_hash()
|
||||
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
print(" > Experiment folder: {}".format(output_folder))
|
||||
return output_folder
|
||||
|
||||
|
||||
def remove_experiment_folder(experiment_path):
|
||||
"""Check folder if there is a checkpoint, otherwise remove the folder"""
|
||||
|
||||
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
|
||||
fs = fsspec.get_mapper(experiment_path).fs
|
||||
checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
|
||||
if not checkpoint_files:
|
||||
if os.path.exists(experiment_path):
|
||||
shutil.rmtree(experiment_path, ignore_errors=True)
|
||||
if fs.exists(experiment_path):
|
||||
fs.rm(experiment_path, recursive=True)
|
||||
print(" ! Run is removed from {}".format(experiment_path))
|
||||
else:
|
||||
print(" ! Run is kept in {}".format(experiment_path))
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import datetime
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import pickle as pickle_tts
|
||||
from shutil import copyfile
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
@ -24,7 +26,7 @@ class AttrDict(dict):
|
|||
self.__dict__ = self
|
||||
|
||||
|
||||
def copy_model_files(config, out_path, new_fields):
|
||||
def copy_model_files(config: Coqpit, out_path, new_fields):
|
||||
"""Copy config.json and other model files to training folder and add
|
||||
new fields.
|
||||
|
||||
|
@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
|
|||
copy_config_path = os.path.join(out_path, "config.json")
|
||||
# add extra information fields
|
||||
config.update(new_fields, allow_new=True)
|
||||
config.save_json(copy_config_path)
|
||||
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
|
||||
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
|
||||
json.dump(config.to_dict(), f, indent=4)
|
||||
|
||||
# copy model stats file if available
|
||||
if config.audio.stats_path is not None:
|
||||
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
||||
if not os.path.exists(copy_stats_path):
|
||||
copyfile(
|
||||
config.audio.stats_path,
|
||||
copy_stats_path,
|
||||
)
|
||||
filesystem = fsspec.get_mapper(copy_stats_path).fs
|
||||
if not filesystem.exists(copy_stats_path):
|
||||
with fsspec.open(config.audio.stats_path, "rb") as source_file:
|
||||
with fsspec.open(copy_stats_path, "wb") as target_file:
|
||||
shutil.copyfileobj(source_file, target_file)
|
||||
|
||||
|
||||
def load_fsspec(path: str, **kwargs) -> Any:
|
||||
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
|
||||
|
||||
Args:
|
||||
path: Any path or url supported by fsspec.
|
||||
**kwargs: Keyword arguments forwarded to torch.load.
|
||||
|
||||
Returns:
|
||||
Object stored in path.
|
||||
"""
|
||||
with fsspec.open(path, "rb") as f:
|
||||
return torch.load(f, **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
|
||||
try:
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
except ModuleNotFoundError:
|
||||
pickle_tts.Unpickler = RenamingUnpickler
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
|
||||
model.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
|
@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
|
|||
return model, state
|
||||
|
||||
|
||||
def save_fsspec(state: Any, path: str, **kwargs):
|
||||
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
|
||||
|
||||
Args:
|
||||
state: State object to save
|
||||
path: Any path or url supported by fsspec.
|
||||
**kwargs: Keyword arguments forwarded to torch.save.
|
||||
"""
|
||||
with fsspec.open(path, "wb") as f:
|
||||
torch.save(state, f, **kwargs)
|
||||
|
||||
|
||||
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
||||
if hasattr(model, "module"):
|
||||
model_state = model.module.state_dict()
|
||||
|
@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
|
|||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
state.update(kwargs)
|
||||
torch.save(state, output_path)
|
||||
save_fsspec(state, output_path)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
|
@ -147,18 +178,16 @@ def save_best_model(
|
|||
model_loss=current_loss,
|
||||
**kwargs,
|
||||
)
|
||||
fs = fsspec.get_mapper(out_path).fs
|
||||
# only delete previous if current is saved successfully
|
||||
if not keep_all_best or (current_step < keep_after):
|
||||
model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
|
||||
model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
|
||||
for model_name in model_names:
|
||||
if os.path.basename(model_name) == best_model_name:
|
||||
continue
|
||||
os.remove(model_name)
|
||||
# create symlink to best model for convinience
|
||||
link_name = "best_model.pth.tar"
|
||||
link_path = os.path.join(out_path, link_name)
|
||||
if os.path.islink(link_path) or os.path.isfile(link_path):
|
||||
os.remove(link_path)
|
||||
os.symlink(best_model_name, os.path.join(out_path, link_name))
|
||||
if os.path.basename(model_name) != best_model_name:
|
||||
fs.rm(model_name)
|
||||
# create a shortcut which always points to the currently best model
|
||||
shortcut_name = "best_model.pth.tar"
|
||||
shortcut_path = os.path.join(out_path, shortcut_name)
|
||||
fs.copy(checkpoint_path, shortcut_path)
|
||||
best_loss = current_loss
|
||||
return best_loss
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
|
@ -222,7 +223,7 @@ class GAN(BaseVocoder):
|
|||
checkpoint_path (str): Checkpoint file path.
|
||||
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
|
||||
"""
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# band-aid for older than v0.0.15 GAN models
|
||||
if "model_disc" in state:
|
||||
self.model_g.load_checkpoint(config, checkpoint_path, eval)
|
||||
|
|
|
@ -5,6 +5,8 @@ import torch.nn.functional as F
|
|||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
|
@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.melgan import ResidualStack
|
||||
|
||||
|
||||
|
@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
from TTS.vocoder.layers.upsample import ConvUpsample
|
||||
|
||||
|
@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
|
||||
from TTS.model import BaseModel
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
|
||||
|
@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
|
@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
|
|||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
|
@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
|
|||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
state.update(kwargs)
|
||||
pickle.dump(state, open(output_path, "wb"))
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
"""Load TF Vocoder model"""
|
||||
checkpoint = pickle.load(open(checkpoint_path, "rb"))
|
||||
with fsspec.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = pickle.load(f)
|
||||
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import fsspec
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
|
@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
|
|||
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
|
||||
if output_path is not None:
|
||||
# same model binary if outputpath is provided
|
||||
with open(output_path, "wb") as f:
|
||||
with fsspec.open(output_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
return None
|
||||
return tflite_model
|
||||
|
|
|
@ -24,3 +24,4 @@ mecab-python3==1.0.3
|
|||
unidic-lite==1.0.8
|
||||
# gruut+supported langs
|
||||
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
|
||||
fsspec>=2021.04.0
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import Tacotron2Config
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
||||
config = Tacotron2Config(
|
||||
r=5,
|
||||
batch_size=8,
|
||||
eval_batch_size=8,
|
||||
num_loader_workers=0,
|
||||
num_eval_loader_workers=0,
|
||||
text_cleaner="english_cleaners",
|
||||
use_phonemes=False,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
|
||||
run_eval=True,
|
||||
test_delay_epochs=-1,
|
||||
epochs=1,
|
||||
print_step=1,
|
||||
test_sentences=[
|
||||
"Be a voice, not an echo.",
|
||||
],
|
||||
print_eval=True,
|
||||
max_decoder_steps=50,
|
||||
)
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
config.save_json(config_path)
|
||||
|
||||
# train the model for one epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
|
||||
f"--coqpit.output_path file://{output_path} "
|
||||
"--coqpit.datasets.0.name ljspeech "
|
||||
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
||||
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
||||
"--coqpit.datasets.0.path tests/data/ljspeech "
|
||||
"--coqpit.test_delay_epochs 0 "
|
||||
)
|
||||
run_cli(command_train)
|
||||
|
||||
# Find latest folder
|
||||
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
command_train = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
|
||||
)
|
||||
run_cli(command_train)
|
||||
shutil.rmtree(continue_path)
|
Loading…
Reference in New Issue