From ced4cfdbbf4496ec0d95483e470933b4fed4f95a Mon Sep 17 00:00:00 2001 From: Agrin Hilmkil Date: Thu, 5 Aug 2021 14:11:26 +0200 Subject: [PATCH] Allow saving / loading checkpoints from cloud paths (#683) * Allow saving / loading checkpoints from cloud paths Allows saving and loading checkpoints directly from cloud paths like Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec. Note: The user will have to install the relevant dependency for each protocol. Otherwise fsspec will fail and specify which dependency is missing. * Append suffix _fsspec to save/load function names * Add a lower bound to the fsspec dependency Skips the 0 major version. * Add missing changes from refactor * Use fsspec for remaining artifacts * Add test case with path requiring fsspec * Avoid writing logs to file unless output_path is local * Document the possibility of using paths supported by fsspec * Fix style and lint * Add missing lint fixes * Add type annotations to new functions * Use Coqpit method for converting config to dict * Fix type annotation in semi-new function * Add return type for load_fsspec * Fix bug where fs not always created * Restore the experiment removal functionality --- TTS/bin/convert_melgan_torch_to_tf.py | 4 +- TTS/bin/convert_tacotron2_torch_to_tf.py | 4 +- TTS/bin/extract_tts_spectrograms.py | 3 +- TTS/bin/train_encoder.py | 3 +- TTS/config/__init__.py | 10 +-- TTS/config/shared_configs.py | 6 +- TTS/speaker_encoder/models/lstm.py | 4 +- TTS/speaker_encoder/models/resnet.py | 4 +- TTS/speaker_encoder/utils/generic_utils.py | 6 +- TTS/speaker_encoder/utils/io.py | 6 +- TTS/trainer.py | 45 +++++++----- TTS/tts/models/align_tts.py | 3 +- TTS/tts/models/base_tacotron.py | 3 +- TTS/tts/models/glow_tts.py | 3 +- TTS/tts/models/speedy_speech.py | 3 +- TTS/tts/tf/utils/generic_utils.py | 7 +- TTS/tts/tf/utils/io.py | 7 +- TTS/tts/tf/utils/tflite.py | 3 +- TTS/tts/utils/speakers.py | 14 ++-- TTS/utils/generic_utils.py | 16 ++-- TTS/utils/io.py | 73 +++++++++++++------ TTS/vocoder/models/gan.py | 3 +- TTS/vocoder/models/hifigan_generator.py | 4 +- TTS/vocoder/models/melgan_generator.py | 3 +- .../models/parallel_wavegan_generator.py | 3 +- TTS/vocoder/models/wavegrad.py | 3 +- TTS/vocoder/models/wavernn.py | 3 +- TTS/vocoder/tf/utils/io.py | 7 +- TTS/vocoder/tf/utils/tflite.py | 3 +- requirements.txt | 1 + .../test_tacotron2_train_fsspec_path.py | 55 ++++++++++++++ 31 files changed, 218 insertions(+), 94 deletions(-) create mode 100644 tests/tts_tests/test_tacotron2_train_fsspec_path.py diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py index 43581348..c1fb8498 100644 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ b/TTS/bin/convert_melgan_torch_to_tf.py @@ -6,7 +6,7 @@ import numpy as np import tensorflow as tf import torch -from TTS.utils.io import load_config +from TTS.utils.io import load_config, load_fsspec from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( compare_torch_tf, convert_tf_name, @@ -33,7 +33,7 @@ num_speakers = 0 # init torch model model = setup_generator(c) -checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) state_dict = checkpoint["model"] model.load_state_dict(state_dict) model.remove_weight_norm() diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py index a6fb5d9b..78c6b362 100644 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ b/TTS/bin/convert_tacotron2_torch_to_tf.py @@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config +from TTS.utils.io import load_config, load_fsspec sys.path.append("/home/erogol/Projects") os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -32,7 +32,7 @@ num_speakers = 0 # init torch model model = setup_model(c) -checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) state_dict = checkpoint["model"] model.load_state_dict(state_dict) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 1cbc5516..debe5933 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -16,6 +16,7 @@ from TTS.tts.models import setup_model from TTS.tts.utils.speakers import get_speaker_manager from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters +from TTS.utils.io import load_fsspec use_cuda = torch.cuda.is_available() @@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name model = setup_model(c) # restore model - checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu") model.load_state_dict(checkpoint["model"]) if use_cuda: diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 2bb5bfc7..43867239 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -17,6 +17,7 @@ from TTS.trainer import init_training from TTS.tts.datasets import load_meta_data from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict +from TTS.utils.io import load_fsspec from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, check_update @@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name raise Exception("The %s not is a loss supported" % c.loss) if args.restore_path: - checkpoint = torch.load(args.restore_path) + checkpoint = load_fsspec(args.restore_path) try: model.load_state_dict(checkpoint["model"]) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index ecbe1f9a..ea98f431 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -3,6 +3,7 @@ import os import re from typing import Dict +import fsspec import yaml from coqpit import Coqpit @@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module def read_json_with_comments(json_path): """for backward compat.""" # fallback to json - with open(json_path, "r", encoding="utf-8") as f: + with fsspec.open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments input_str = re.sub(r"\\\n", "", input_str) @@ -76,13 +77,12 @@ def load_config(config_path: str) -> None: config_dict = {} ext = os.path.splitext(config_path)[1] if ext in (".yml", ".yaml"): - with open(config_path, "r", encoding="utf-8") as f: + with fsspec.open(config_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) elif ext == ".json": try: - with open(config_path, "r", encoding="utf-8") as f: - input_str = f.read() - data = json.loads(input_str) + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) except json.decoder.JSONDecodeError: # backwards compat. data = read_json_with_comments(config_path) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 669437b9..0ec7f758 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit): num_eval_loader_workers (int): Number of workers for evaluation time dataloader. output_path (str): - Path for training output folder. The nonexist part of the given path is created automatically. - All training outputs are saved there. + Path for training output folder, either a local file path or other + URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or + S3 (s3://) paths. The nonexist part of the given path is created + automatically. All training artefacts are saved there. """ model: str = None diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index 7e39087a..de5bb007 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -2,6 +2,8 @@ import numpy as np import torch from torch import nn +from TTS.utils.io import load_fsspec + class LSTMWithProjection(nn.Module): def __init__(self, input_size, hidden_size, proj_size): @@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if use_cuda: self.cuda() diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index f52bb4d5..f121631b 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -2,6 +2,8 @@ import numpy as np import torch import torch.nn as nn +from TTS.utils.io import load_fsspec + class SELayer(nn.Module): def __init__(self, channel, reduction=8): @@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module): return embeddings def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if use_cuda: self.cuda() diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index fb61e48e..1981fbe9 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -6,11 +6,11 @@ import re from multiprocessing import Manager import numpy as np -import torch from scipy import signal from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder +from TTS.utils.io import save_fsspec class Storage(object): @@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s "loss": model_loss, "date": datetime.date.today().strftime("%B %d, %Y"), } - torch.save(state, checkpoint_path) + save_fsspec(state, checkpoint_path) def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): @@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - torch.save(state, bestmodel_path) + save_fsspec(state, bestmodel_path) return best_loss diff --git a/TTS/speaker_encoder/utils/io.py b/TTS/speaker_encoder/utils/io.py index 0479f1af..7a3aadc9 100644 --- a/TTS/speaker_encoder/utils/io.py +++ b/TTS/speaker_encoder/utils/io.py @@ -1,7 +1,7 @@ import datetime import os -import torch +from TTS.utils.io import save_fsspec def save_checkpoint(model, optimizer, model_loss, out_path, current_step): @@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step): "loss": model_loss, "date": datetime.date.today().strftime("%B %d, %Y"), } - torch.save(state, checkpoint_path) + save_fsspec(state, checkpoint_path) def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): @@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - torch.save(state, bestmodel_path) + save_fsspec(state, bestmodel_path) return best_loss diff --git a/TTS/trainer.py b/TTS/trainer.py index 903aee5f..3ac83601 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import glob import importlib import logging import os @@ -12,7 +11,9 @@ import traceback from argparse import Namespace from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union +from urllib.parse import urlparse +import fsspec import torch from coqpit import Coqpit from torch import nn @@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import ( KeepAverage, count_parameters, - create_experiment_folder, + get_experiment_folder_path, get_git_branch, remove_experiment_folder, set_init_dict, to_cuda, ) -from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint +from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data @@ -173,7 +174,6 @@ class Trainer: self.best_loss = float("inf") self.train_loader = None self.eval_loader = None - self.output_audio_path = os.path.join(output_path, "test_audios") self.keep_avg_train = None self.keep_avg_eval = None @@ -309,7 +309,7 @@ class Trainer: return obj print(" > Restoring from %s ..." % os.path.basename(restore_path)) - checkpoint = torch.load(restore_path) + checkpoint = load_fsspec(restore_path) try: print(" > Restoring Model...") model.load_state_dict(checkpoint["model"]) @@ -776,7 +776,7 @@ class Trainer: """🏃 train -> evaluate -> test for the number of epochs.""" if self.restore_step != 0 or self.args.best_path: print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") - self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] + self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {self.best_loss}.") self.total_steps_done = self.restore_step @@ -834,9 +834,16 @@ class Trainer: @staticmethod def _setup_logger_config(log_file: str) -> None: - logging.basicConfig( - level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()] - ) + handlers = [logging.StreamHandler()] + + # Only add a log file if the output location is local due to poor + # support for writing logs to file-like objects. + parsed_url = urlparse(log_file) + if not parsed_url.scheme or parsed_url.scheme == "file": + schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path) + handlers.append(logging.FileHandler(schemeless_path)) + + logging.basicConfig(level=logging.INFO, format="", handlers=handlers) @staticmethod def _is_apex_available() -> bool: @@ -926,22 +933,27 @@ def init_arguments(): return parser -def get_last_checkpoint(path): +def get_last_checkpoint(path: str) -> Tuple[str, str]: """Get latest checkpoint or/and best model in path. It is based on globbing for `*.pth.tar` and the RegEx `(checkpoint|best_model)_([0-9]+)`. Args: - path (list): Path to files to be compared. + path: Path to files to be compared. Raises: ValueError: If no checkpoint or best_model files are found. Returns: - last_checkpoint (str): Last checkpoint filename. + Path to the last checkpoint + Path to best checkpoint """ - file_names = glob.glob(os.path.join(path, "*.pth.tar")) + fs = fsspec.get_mapper(path).fs + file_names = fs.glob(os.path.join(path, "*.pth.tar")) + scheme = urlparse(path).scheme + if scheme: # scheme is not preserved in fs.glob, add it back + file_names = [scheme + "://" + file_name for file_name in file_names] last_models = {} last_model_nums = {} for key in ["checkpoint", "best_model"]: @@ -963,7 +975,7 @@ def get_last_checkpoint(path): key_file_names = [fn for fn in file_names if key in fn] if last_model is None and len(key_file_names) > 0: last_model = max(key_file_names, key=os.path.getctime) - last_model_num = torch.load(last_model)["step"] + last_model_num = load_fsspec(last_model)["step"] if last_model is not None: last_models[key] = last_model @@ -1030,12 +1042,11 @@ def process_args(args, config=None): print(" > Mixed precision mode is ON") experiment_path = args.continue_path if not experiment_path: - experiment_path = create_experiment_folder(config.output_path, config.run_name) + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) audio_path = os.path.join(experiment_path, "test_audios") # setup rank 0 process in distributed training tb_logger = None if args.rank == 0: - os.makedirs(audio_path, exist_ok=True) new_fields = {} if args.restore_path: new_fields["restore_path"] = args.restore_path @@ -1047,8 +1058,6 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - os.chmod(audio_path, 0o775) - os.chmod(experiment_path, 0o775) tb_logger = TensorboardLogger(experiment_path, model_name=config.model) # write model desc to tensorboard tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 879ecae4..fb2fa697 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec @dataclass @@ -389,7 +390,7 @@ class AlignTTS(BaseTTS): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index b7056e06..2d2cc111 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.text import make_symbols from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler @@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if "r" in state: self.decoder.set_r(state["r"]) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index b3bceb09..1c631c8e 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec class GlowTTS(BaseTTS): @@ -382,7 +383,7 @@ class GlowTTS(BaseTTS): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 8f14d610..33b9cb66 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec @dataclass @@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py index 91434a38..681a9457 100644 --- a/TTS/tts/tf/utils/generic_utils.py +++ b/TTS/tts/tf/utils/generic_utils.py @@ -2,6 +2,7 @@ import datetime import importlib import pickle +import fsspec import numpy as np import tensorflow as tf @@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, "wb")) + with fsspec.open(output_path, "wb") as f: + pickle.dump(state, f) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, "rb")) + with fsspec.open(checkpoint_path, "rb") as f: + checkpoint = pickle.load(f) chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py index b2345b00..de6acff9 100644 --- a/TTS/tts/tf/utils/io.py +++ b/TTS/tts/tf/utils/io.py @@ -1,6 +1,7 @@ import datetime import pickle +import fsspec import tensorflow as tf @@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, "wb")) + with fsspec.open(output_path, "wb") as f: + pickle.dump(state, f) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, "rb")) + with fsspec.open(checkpoint_path, "rb") as f: + checkpoint = pickle.load(f) chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py index 9701d591..2f76aa50 100644 --- a/TTS/tts/tf/utils/tflite.py +++ b/TTS/tts/tf/utils/tflite.py @@ -1,3 +1,4 @@ +import fsspec import tensorflow as tf @@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter= print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, "wb") as f: + with fsspec.open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index a8c9e0f6..ed14cd8e 100755 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -3,6 +3,7 @@ import os import random from typing import Any, Dict, List, Tuple, Union +import fsspec import numpy as np import torch from coqpit import Coqpit @@ -84,12 +85,12 @@ class SpeakerManager: @staticmethod def _load_json(json_file_path: str) -> Dict: - with open(json_file_path) as f: + with fsspec.open(json_file_path, "r") as f: return json.load(f) @staticmethod def _save_json(json_file_path: str, data: dict) -> None: - with open(json_file_path, "w") as f: + with fsspec.open(json_file_path, "w") as f: json.dump(data, f, indent=4) @property @@ -294,9 +295,10 @@ def _set_file_path(path): Intended to band aid the different paths returned in restored and continued training.""" path_restore = os.path.join(os.path.dirname(path), "speakers.json") path_continue = os.path.join(path, "speakers.json") - if os.path.exists(path_restore): + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): return path_restore - if os.path.exists(path_continue): + if fs.exists(path_continue): return path_continue raise FileNotFoundError(f" [!] `speakers.json` not found in {path}") @@ -307,7 +309,7 @@ def load_speaker_mapping(out_path): json_file = out_path else: json_file = _set_file_path(out_path) - with open(json_file) as f: + with fsspec.open(json_file, "r") as f: return json.load(f) @@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping): """Saves speaker mapping if not yet present.""" if out_path is not None: speakers_json_path = _set_file_path(out_path) - with open(speakers_json_path, "w") as f: + with fsspec.open(speakers_json_path, "w") as f: json.dump(speaker_mapping, f, indent=4) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index e7c57529..287143e5 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -1,15 +1,14 @@ # -*- coding: utf-8 -*- import datetime -import glob import importlib import os import re -import shutil import subprocess import sys from pathlib import Path from typing import Dict +import fsspec import torch @@ -58,23 +57,22 @@ def get_commit_hash(): return commit -def create_experiment_folder(root_path, model_name): - """Create a folder with the current date and time""" +def get_experiment_folder_path(root_path, model_name): + """Get an experiment folder path with the current date and time""" date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") commit_hash = get_commit_hash() output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) - os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) return output_folder def remove_experiment_folder(experiment_path): """Check folder if there is a checkpoint, otherwise remove the folder""" - - checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") + fs = fsspec.get_mapper(experiment_path).fs + checkpoint_files = fs.glob(experiment_path + "/*.pth.tar") if not checkpoint_files: - if os.path.exists(experiment_path): - shutil.rmtree(experiment_path, ignore_errors=True) + if fs.exists(experiment_path): + fs.rm(experiment_path, recursive=True) print(" ! Run is removed from {}".format(experiment_path)) else: print(" ! Run is kept in {}".format(experiment_path)) diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 871cff6c..f634b023 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -1,9 +1,11 @@ import datetime -import glob +import json import os import pickle as pickle_tts -from shutil import copyfile +import shutil +from typing import Any +import fsspec import torch from coqpit import Coqpit @@ -24,7 +26,7 @@ class AttrDict(dict): self.__dict__ = self -def copy_model_files(config, out_path, new_fields): +def copy_model_files(config: Coqpit, out_path, new_fields): """Copy config.json and other model files to training folder and add new fields. @@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields): copy_config_path = os.path.join(out_path, "config.json") # add extra information fields config.update(new_fields, allow_new=True) - config.save_json(copy_config_path) + # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. + with fsspec.open(copy_config_path, "w", encoding="utf8") as f: + json.dump(config.to_dict(), f, indent=4) + # copy model stats file if available if config.audio.stats_path is not None: copy_stats_path = os.path.join(out_path, "scale_stats.npy") - if not os.path.exists(copy_stats_path): - copyfile( - config.audio.stats_path, - copy_stats_path, - ) + filesystem = fsspec.get_mapper(copy_stats_path).fs + if not filesystem.exists(copy_stats_path): + with fsspec.open(config.audio.stats_path, "rb") as source_file: + with fsspec.open(copy_stats_path, "wb") as target_file: + shutil.copyfileobj(source_file, target_file) + + +def load_fsspec(path: str, **kwargs) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + with fsspec.open(path, "rb") as f: + return torch.load(f, **kwargs) def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin try: - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) model.load_state_dict(state["model"]) if use_cuda: model.cuda() @@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli return model, state +def save_fsspec(state: Any, path: str, **kwargs): + """Like torch.save but can save to other locations (e.g. s3:// , gs://). + + Args: + state: State object to save + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.save. + """ + with fsspec.open(path, "wb") as f: + torch.save(state, f, **kwargs) + + def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): if hasattr(model, "module"): model_state = model.module.state_dict() @@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) - torch.save(state, output_path) + save_fsspec(state, output_path) def save_checkpoint( @@ -147,18 +178,16 @@ def save_best_model( model_loss=current_loss, **kwargs, ) + fs = fsspec.get_mapper(out_path).fs # only delete previous if current is saved successfully if not keep_all_best or (current_step < keep_after): - model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar")) + model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar")) for model_name in model_names: - if os.path.basename(model_name) == best_model_name: - continue - os.remove(model_name) - # create symlink to best model for convinience - link_name = "best_model.pth.tar" - link_path = os.path.join(out_path, link_name) - if os.path.islink(link_path) or os.path.isfile(link_path): - os.remove(link_path) - os.symlink(best_model_name, os.path.join(out_path, link_name)) + if os.path.basename(model_name) != best_model_name: + fs.rm(model_name) + # create a shortcut which always points to the currently best model + shortcut_name = "best_model.pth.tar" + shortcut_path = os.path.join(out_path, shortcut_name) + fs.copy(checkpoint_path, shortcut_path) best_loss = current_loss return best_loss diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 39176155..f203c533 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss @@ -222,7 +223,7 @@ class GAN(BaseVocoder): checkpoint_path (str): Checkpoint file path. eval (bool, optional): If true, load the model for inference. If falseDefaults to False. """ - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) # band-aid for older than v0.0.15 GAN models if "model_disc" in state: self.model_g.load_checkpoint(config, checkpoint_path, eval) diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index f606c649..2260b781 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from torch.nn import Conv1d, ConvTranspose1d from torch.nn.utils import remove_weight_norm, weight_norm +from TTS.utils.io import load_fsspec + LRELU_SLOPE = 0.1 @@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index dabb4baa..e60baa9d 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -2,6 +2,7 @@ import torch from torch import nn from torch.nn.utils import weight_norm +from TTS.utils.io import load_fsspec from TTS.vocoder.layers.melgan import ResidualStack @@ -86,7 +87,7 @@ class MelganGenerator(nn.Module): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 788856cc..b8e78d03 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -3,6 +3,7 @@ import math import numpy as np import torch +from TTS.utils.io import load_fsspec from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.upsample import ConvUpsample @@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index d2983be2..5dc878d7 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock @@ -220,7 +221,7 @@ class Wavegrad(BaseModel): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index c2e47120..8a968019 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss from TTS.vocoder.models.base_vocoder import BaseVocoder @@ -545,7 +546,7 @@ class Wavernn(BaseVocoder): def load_checkpoint( self, config, checkpoint_path, eval=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py index 7e236db2..3de8adab 100644 --- a/TTS/vocoder/tf/utils/io.py +++ b/TTS/vocoder/tf/utils/io.py @@ -1,6 +1,7 @@ import datetime import pickle +import fsspec import tensorflow as tf @@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs): "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) - pickle.dump(state, open(output_path, "wb")) + with fsspec.open(output_path, "wb") as f: + pickle.dump(state, f) def load_checkpoint(model, checkpoint_path): """Load TF Vocoder model""" - checkpoint = pickle.load(open(checkpoint_path, "rb")) + with fsspec.open(checkpoint_path, "rb") as f: + checkpoint = pickle.load(f) chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py index e0c630b9..876739fd 100644 --- a/TTS/vocoder/tf/utils/tflite.py +++ b/TTS/vocoder/tf/utils/tflite.py @@ -1,3 +1,4 @@ +import fsspec import tensorflow as tf @@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, "wb") as f: + with fsspec.open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/requirements.txt b/requirements.txt index d5624c3b..b92947a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ mecab-python3==1.0.3 unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0 +fsspec>=2021.04.0 diff --git a/tests/tts_tests/test_tacotron2_train_fsspec_path.py b/tests/tts_tests/test_tacotron2_train_fsspec_path.py new file mode 100644 index 00000000..9e4ee102 --- /dev/null +++ b/tests/tts_tests/test_tacotron2_train_fsspec_path.py @@ -0,0 +1,55 @@ +import glob +import os +import shutil + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs import Tacotron2Config + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +config = Tacotron2Config( + r=5, + batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=False, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"), + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + test_sentences=[ + "Be a voice, not an echo.", + ], + print_eval=True, + max_decoder_steps=50, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} " + f"--coqpit.output_path file://{output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.test_delay_epochs 0 " +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} " +) +run_cli(command_train) +shutil.rmtree(continue_path)