Allow saving / loading checkpoints from cloud paths (#683)

* Allow saving / loading checkpoints from cloud paths

Allows saving and loading checkpoints directly from cloud paths like
Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec.

Note: The user must install the relevant fsspec backend for each
protocol (e.g. s3fs for s3://, gcsfs for gs://). Otherwise fsspec will
fail with an error that names the missing dependency.

* Append suffix _fsspec to save/load function names

* Add a lower bound to the fsspec dependency

Skips the 0 major version.

* Add missing changes from refactor

* Use fsspec for remaining artifacts

* Add test case with path requiring fsspec

* Avoid writing logs to file unless output_path is local

* Document the possibility of using paths supported by fsspec

* Fix style and lint

* Add missing lint fixes

* Add type annotations to new functions

* Use Coqpit method for converting config to dict

* Fix type annotation in semi-new function

* Add return type for load_fsspec

* Fix bug where fs not always created

* Restore the experiment removal functionality
This commit is contained in:
Agrin Hilmkil 2021-08-05 14:11:26 +02:00 committed by Eren Gölge
parent 181177a990
commit ced4cfdbbf
31 changed files with 218 additions and 94 deletions

View File

@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from TTS.utils.io import load_config from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, compare_torch_tf,
convert_tf_name, convert_tf_name,
@ -33,7 +33,7 @@ num_speakers = 0
# init torch model # init torch model
model = setup_generator(c) model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.remove_weight_norm() model.remove_weight_norm()

View File

@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects") sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = "" os.environ["CUDA_VISIBLE_DEVICES"] = ""
@ -32,7 +32,7 @@ num_speakers = 0
# init torch model # init torch model
model = setup_model(c) model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"] state_dict = checkpoint["model"]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)

View File

@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c) model = setup_model(c)
# restore model # restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu") checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
if use_cuda: if use_cuda:

View File

@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update from TTS.utils.training import NoamLR, check_update
@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
raise Exception("The %s not is a loss supported" % c.loss) raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = load_fsspec(args.restore_path)
try: try:
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])

View File

@ -3,6 +3,7 @@ import os
import re import re
from typing import Dict from typing import Dict
import fsspec
import yaml import yaml
from coqpit import Coqpit from coqpit import Coqpit
@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path): def read_json_with_comments(json_path):
"""for backward compat.""" """for backward compat."""
# fallback to json # fallback to json
with open(json_path, "r", encoding="utf-8") as f: with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read() input_str = f.read()
# handle comments # handle comments
input_str = re.sub(r"\\\n", "", input_str) input_str = re.sub(r"\\\n", "", input_str)
@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
config_dict = {} config_dict = {}
ext = os.path.splitext(config_path)[1] ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"): if ext in (".yml", ".yaml"):
with open(config_path, "r", encoding="utf-8") as f: with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
elif ext == ".json": elif ext == ".json":
try: try:
with open(config_path, "r", encoding="utf-8") as f: with fsspec.open(config_path, "r", encoding="utf-8") as f:
input_str = f.read() data = json.load(f)
data = json.loads(input_str)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
# backwards compat. # backwards compat.
data = read_json_with_comments(config_path) data = read_json_with_comments(config_path)

View File

@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit):
num_eval_loader_workers (int): num_eval_loader_workers (int):
Number of workers for evaluation time dataloader. Number of workers for evaluation time dataloader.
output_path (str): output_path (str):
Path for training output folder. The nonexist part of the given path is created automatically. Path for training output folder, either a local file path or other
All training outputs are saved there. URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. Any non-existent part of the given path is created
automatically. All training artefacts are saved there.
""" """
model: str = None model: str = None

View File

@ -2,6 +2,8 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from TTS.utils.io import load_fsspec
class LSTMWithProjection(nn.Module): class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size): def __init__(self, input_size, hidden_size, proj_size):
@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
# pylint: disable=unused-argument, redefined-builtin # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if use_cuda: if use_cuda:
self.cuda() self.cuda()

View File

@ -2,6 +2,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from TTS.utils.io import load_fsspec
class SELayer(nn.Module): class SELayer(nn.Module):
def __init__(self, channel, reduction=8): def __init__(self, channel, reduction=8):
@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
return embeddings return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if use_cuda: if use_cuda:
self.cuda() self.cuda()

View File

@ -6,11 +6,11 @@ import re
from multiprocessing import Manager from multiprocessing import Manager
import numpy as np import numpy as np
import torch
from scipy import signal from scipy import signal
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class Storage(object): class Storage(object):
@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
torch.save(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
bestmodel_path = "best_model.pth.tar" bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path) bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path) save_fsspec(state, bestmodel_path)
return best_loss return best_loss

View File

@ -1,7 +1,7 @@
import datetime import datetime
import os import os
import torch from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step): def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
"loss": model_loss, "loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
torch.save(state, checkpoint_path) save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
bestmodel_path = "best_model.pth.tar" bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path) bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path) save_fsspec(state, bestmodel_path)
return best_loss return best_loss

View File

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import glob
import importlib import importlib
import logging import logging
import os import os
@ -12,7 +11,9 @@ import traceback
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import ( from TTS.utils.generic_utils import (
KeepAverage, KeepAverage,
count_parameters, count_parameters,
create_experiment_folder, get_experiment_folder_path,
get_git_branch, get_git_branch,
remove_experiment_folder, remove_experiment_folder,
set_init_dict, set_init_dict,
to_cuda, to_cuda,
) )
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@ -173,7 +174,6 @@ class Trainer:
self.best_loss = float("inf") self.best_loss = float("inf")
self.train_loader = None self.train_loader = None
self.eval_loader = None self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None self.keep_avg_train = None
self.keep_avg_eval = None self.keep_avg_eval = None
@ -309,7 +309,7 @@ class Trainer:
return obj return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path)) print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path) checkpoint = load_fsspec(restore_path)
try: try:
print(" > Restoring Model...") print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
@ -776,7 +776,7 @@ class Trainer:
"""🏃 train -> evaluate -> test for the number of epochs.""" """🏃 train -> evaluate -> test for the number of epochs."""
if self.restore_step != 0 or self.args.best_path: if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.") print(f" > Starting with loaded last best loss {self.best_loss}.")
self.total_steps_done = self.restore_step self.total_steps_done = self.restore_step
@ -834,9 +834,16 @@ class Trainer:
@staticmethod @staticmethod
def _setup_logger_config(log_file: str) -> None: def _setup_logger_config(log_file: str) -> None:
logging.basicConfig( handlers = [logging.StreamHandler()]
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
) # Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
@staticmethod @staticmethod
def _is_apex_available() -> bool: def _is_apex_available() -> bool:
@ -926,22 +933,27 @@ def init_arguments():
return parser return parser
def get_last_checkpoint(path): def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path. """Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth.tar` and the RegEx It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`. `(checkpoint|best_model)_([0-9]+)`.
Args: Args:
path (list): Path to files to be compared. path: Path to files to be compared.
Raises: Raises:
ValueError: If no checkpoint or best_model files are found. ValueError: If no checkpoint or best_model files are found.
Returns: Returns:
last_checkpoint (str): Last checkpoint filename. Path to the last checkpoint
Path to best checkpoint
""" """
file_names = glob.glob(os.path.join(path, "*.pth.tar")) fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {} last_models = {}
last_model_nums = {} last_model_nums = {}
for key in ["checkpoint", "best_model"]: for key in ["checkpoint", "best_model"]:
@ -963,7 +975,7 @@ def get_last_checkpoint(path):
key_file_names = [fn for fn in file_names if key in fn] key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0: if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime) last_model = max(key_file_names, key=os.path.getctime)
last_model_num = torch.load(last_model)["step"] last_model_num = load_fsspec(last_model)["step"]
if last_model is not None: if last_model is not None:
last_models[key] = last_model last_models[key] = last_model
@ -1030,12 +1042,11 @@ def process_args(args, config=None):
print(" > Mixed precision mode is ON") print(" > Mixed precision mode is ON")
experiment_path = args.continue_path experiment_path = args.continue_path
if not experiment_path: if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name) experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
tb_logger = None tb_logger = None
if args.rank == 0: if args.rank == 0:
os.makedirs(audio_path, exist_ok=True)
new_fields = {} new_fields = {}
if args.restore_path: if args.restore_path:
new_fields["restore_path"] = args.restore_path new_fields["restore_path"] = args.restore_path
@ -1047,8 +1058,6 @@ def process_args(args, config=None):
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields["characters"] = used_characters new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields) copy_model_files(config, experiment_path, new_fields)
os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model) tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0) tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)

View File

@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass @dataclass
@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler from TTS.utils.training import gradual_training_scheduler
@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if "r" in state: if "r" in state:
self.decoder.set_r(state["r"]) self.decoder.set_r(state["r"])

View File

@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
class GlowTTS(BaseTTS): class GlowTTS(BaseTTS):
@ -382,7 +383,7 @@ class GlowTTS(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass @dataclass
@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -2,6 +2,7 @@ import datetime
import importlib import importlib
import pickle import pickle
import fsspec
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r, "r": r,
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,6 +1,7 @@
import datetime import datetime
import pickle import pickle
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r, "r": r,
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None: if output_path is not None:
# save model binary if output_path is provided # save model binary if output_path is provided
with open(output_path, "wb") as f: with fsspec.open(output_path, "wb") as f:
f.write(tflite_model) f.write(tflite_model)
return None return None
return tflite_model return tflite_model

View File

@ -3,6 +3,7 @@ import os
import random import random
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np import numpy as np
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
@ -84,12 +85,12 @@ class SpeakerManager:
@staticmethod @staticmethod
def _load_json(json_file_path: str) -> Dict: def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f: with fsspec.open(json_file_path, "r") as f:
return json.load(f) return json.load(f)
@staticmethod @staticmethod
def _save_json(json_file_path: str, data: dict) -> None: def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f: with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
@property @property
@ -294,9 +295,10 @@ def _set_file_path(path):
Intended to band aid the different paths returned in restored and continued training.""" Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "speakers.json") path_restore = os.path.join(os.path.dirname(path), "speakers.json")
path_continue = os.path.join(path, "speakers.json") path_continue = os.path.join(path, "speakers.json")
if os.path.exists(path_restore): fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore return path_restore
if os.path.exists(path_continue): if fs.exists(path_continue):
return path_continue return path_continue
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}") raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
json_file = out_path json_file = out_path
else: else:
json_file = _set_file_path(out_path) json_file = _set_file_path(out_path)
with open(json_file) as f: with fsspec.open(json_file, "r") as f:
return json.load(f) return json.load(f)
@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present.""" """Saves speaker mapping if not yet present."""
if out_path is not None: if out_path is not None:
speakers_json_path = _set_file_path(out_path) speakers_json_path = _set_file_path(out_path)
with open(speakers_json_path, "w") as f: with fsspec.open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4) json.dump(speaker_mapping, f, indent=4)

View File

@ -1,15 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime
import glob
import importlib import importlib
import os import os
import re import re
import shutil
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
import fsspec
import torch import torch
@ -58,23 +57,22 @@ def get_commit_hash():
return commit return commit
def create_experiment_folder(root_path, model_name): def get_experiment_folder_path(root_path, model_name):
"""Create a folder with the current date and time""" """Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash() commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder)) print(" > Experiment folder: {}".format(output_folder))
return output_folder return output_folder
def remove_experiment_folder(experiment_path): def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder""" """Check folder if there is a checkpoint, otherwise remove the folder"""
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files: if not checkpoint_files:
if os.path.exists(experiment_path): if fs.exists(experiment_path):
shutil.rmtree(experiment_path, ignore_errors=True) fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path)) print(" ! Run is removed from {}".format(experiment_path))
else: else:
print(" ! Run is kept in {}".format(experiment_path)) print(" ! Run is kept in {}".format(experiment_path))

View File

@ -1,9 +1,11 @@
import datetime import datetime
import glob import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
from shutil import copyfile import shutil
from typing import Any
import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
@ -24,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self self.__dict__ = self
def copy_model_files(config, out_path, new_fields): def copy_model_files(config: Coqpit, out_path, new_fields):
"""Copy config.json and other model files to training folder and add """Copy config.json and other model files to training folder and add
new fields. new fields.
@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
copy_config_path = os.path.join(out_path, "config.json") copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields # add extra information fields
config.update(new_fields, allow_new=True) config.update(new_fields, allow_new=True)
config.save_json(copy_config_path) # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available # copy model stats file if available
if config.audio.stats_path is not None: if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy") copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path): filesystem = fsspec.get_mapper(copy_stats_path).fs
copyfile( if not filesystem.exists(copy_stats_path):
config.audio.stats_path, with fsspec.open(config.audio.stats_path, "rb") as source_file:
copy_stats_path, with fsspec.open(copy_stats_path, "wb") as target_file:
) shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any:
    """Read a torch-serialized object through fsspec.

    Behaves like ``torch.load`` but the file is opened with ``fsspec.open``,
    so other locations (e.g. s3:// , gs://) can be used as well.

    Args:
        path: Any path or url supported by fsspec.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    with fsspec.open(path, "rb") as stream:
        loaded = torch.load(stream, **kwargs)
    return loaded
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try: try:
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError: except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"]) model.load_state_dict(state["model"])
if use_cuda: if use_cuda:
model.cuda() model.cuda()
@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
return model, state return model, state
def save_fsspec(state: Any, path: str, **kwargs):
    """Write an object with ``torch.save`` through fsspec.

    Behaves like ``torch.save`` but the target is opened with ``fsspec.open``,
    so other locations (e.g. s3:// , gs://) can be used as well.

    Args:
        state: State object to save
        path: Any path or url supported by fsspec.
        **kwargs: Keyword arguments forwarded to torch.save.
    """
    with fsspec.open(path, "wb") as stream:
        torch.save(state, stream, **kwargs)
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"): if hasattr(model, "module"):
model_state = model.module.state_dict() model_state = model.module.state_dict()
@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
state.update(kwargs) state.update(kwargs)
torch.save(state, output_path) save_fsspec(state, output_path)
def save_checkpoint( def save_checkpoint(
@ -147,18 +178,16 @@ def save_best_model(
model_loss=current_loss, model_loss=current_loss,
**kwargs, **kwargs,
) )
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully # only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after): if not keep_all_best or (current_step < keep_after):
model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar")) model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
for model_name in model_names: for model_name in model_names:
if os.path.basename(model_name) == best_model_name: if os.path.basename(model_name) != best_model_name:
continue fs.rm(model_name)
os.remove(model_name) # create a shortcut which always points to the currently best model
# create symlink to best model for convinience shortcut_name = "best_model.pth.tar"
link_name = "best_model.pth.tar" shortcut_path = os.path.join(out_path, shortcut_name)
link_path = os.path.join(out_path, link_name) fs.copy(checkpoint_path, shortcut_path)
if os.path.islink(link_path) or os.path.isfile(link_path):
os.remove(link_path)
os.symlink(best_model_name, os.path.join(out_path, link_name))
best_loss = current_loss best_loss = current_loss
return best_loss return best_loss

View File

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@ -222,7 +223,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path. checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. Defaults to False. eval (bool, optional): If true, load the model for inference. Defaults to False.
""" """
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models # band-aid for older than v0.0.15 GAN models
if "model_disc" in state: if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval) self.model_g.load_checkpoint(config, checkpoint_path, eval)

View File

@ -5,6 +5,8 @@ import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -2,6 +2,7 @@ import torch
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack from TTS.vocoder.layers.melgan import ResidualStack
@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -3,6 +3,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample from TTS.vocoder.layers.upsample import ConvUpsample
@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.models.base_vocoder import BaseVocoder
@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -1,6 +1,7 @@
import datetime import datetime
import pickle import pickle
import fsspec
import tensorflow as tf import tensorflow as tf
@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
"date": datetime.date.today().strftime("%B %d, %Y"), "date": datetime.date.today().strftime("%B %d, %Y"),
} }
state.update(kwargs) state.update(kwargs)
pickle.dump(state, open(output_path, "wb")) with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
"""Load TF Vocoder model""" """Load TF Vocoder model"""
checkpoint = pickle.load(open(checkpoint_path, "rb")) with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights tf_vars = model.weights
for tf_var in tf_vars: for tf_var in tf_vars:

View File

@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf import tensorflow as tf
@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None: if output_path is not None:
# same model binary if outputpath is provided # same model binary if outputpath is provided
with open(output_path, "wb") as f: with fsspec.open(output_path, "wb") as f:
f.write(tflite_model) f.write(tflite_model)
return None return None
return tflite_model return tflite_model

View File

@ -24,3 +24,4 @@ mecab-python3==1.0.3
unidic-lite==1.0.8 unidic-lite==1.0.8
# gruut+supported langs # gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
fsspec>=2021.04.0

View File

@ -0,0 +1,55 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
# Test script: train a Tacotron2 model for one epoch through the CLI, then
# continue training from the produced run folder. The config, output and
# continue paths are passed with a "file://" scheme prefix.
# NOTE(review): the prefix is presumably there to force the fsspec code path
# for otherwise local files - confirm against the training entry point.
# Locations (under the tests output directory) for the generated config file
# and for the training outputs.
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
# Small configuration: one epoch, batch size 8, no loader workers.
config = Tacotron2Config(
r=5,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
test_sentences=[
"Be a voice, not an echo.",
],
print_eval=True,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
# The config itself is written to a plain local path; only the CLI arguments
# below carry the "file://" prefix.
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
f"--coqpit.output_path file://{output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# Find latest folder
# (the most recently modified sub-directory of output_path, looked up on the
# local filesystem with glob)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
# Clean up the run folder that was just continued from.
shutil.rmtree(continue_path)