Allow saving / loading checkpoints from cloud paths (#683)

* Allow saving / loading checkpoints from cloud paths

Allows saving and loading checkpoints directly from cloud paths such as
Amazon S3 (s3://) and Google Cloud Storage (gs://) by using fsspec.

Note: The user has to install the relevant dependency for each
protocol, e.g. s3fs for S3 and gcsfs for GCS. Otherwise fsspec will
fail with an error that names the missing dependency.
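
As an illustration, once the protocol dependency is installed the new
helpers can round-trip a checkpoint through any fsspec URL. A minimal
sketch, assuming s3fs is installed and using a placeholder bucket name:

    import torch
    from TTS.utils.io import load_fsspec, save_fsspec

    state = {"model": {}, "step": 0}

    # save_fsspec wraps torch.save in fsspec.open, so any supported
    # protocol (file://, s3://, gs://, ...) works as a target path.
    save_fsspec(state, "s3://my-bucket/run/checkpoint_100.pth.tar")

    # load_fsspec mirrors torch.load and forwards keyword arguments to it.
    state = load_fsspec(
        "s3://my-bucket/run/checkpoint_100.pth.tar",
        map_location=torch.device("cpu"),
    )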

* Append suffix _fsspec to save/load function names

* Add a lower bound to the fsspec dependency

Requires at least fsspec 2021.04.0, skipping the 0.x version series.

* Add missing changes from refactor

* Use fsspec for remaining artifacts

* Add test case with path requiring fsspec

* Avoid writing logs to file unless output_path is local

* Document the possibility of using paths supported by fsspec

* Fix style and lint

* Add missing lint fixes

* Add type annotations to new functions

* Use Coqpit method for converting config to dict

* Fix type annotation in semi-new function

* Add return type for load_fsspec

* Fix bug where fs not always created

* Restore the experiment removal functionality
Agrin Hilmkil 2021-08-05 14:11:26 +02:00 committed by Eren Gölge
parent 181177a990
commit ced4cfdbbf
31 changed files with 218 additions and 94 deletions

View File

@@ -6,7 +6,7 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.utils.io import load_config
from TTS.utils.io import load_config, load_fsspec
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf,
convert_tf_name,
@@ -33,7 +33,7 @@ num_speakers = 0
# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()

View File

@@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
from TTS.utils.io import load_config, load_fsspec
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -32,7 +32,7 @@ num_speakers = 0
# init torch model
model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)

View File

@@ -16,6 +16,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec
use_cuda = torch.cuda.is_available()
@@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c)
# restore model
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
if use_cuda:

View File

@@ -17,6 +17,7 @@ from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, check_update
@@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name
raise Exception("The %s not is a loss supported" % c.loss)
if args.restore_path:
checkpoint = torch.load(args.restore_path)
checkpoint = load_fsspec(args.restore_path)
try:
model.load_state_dict(checkpoint["model"])

View File

@@ -3,6 +3,7 @@ import os
import re
from typing import Dict
import fsspec
import yaml
from coqpit import Coqpit
@@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path):
"""for backward compat."""
# fallback to json
with open(json_path, "r", encoding="utf-8") as f:
with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments
input_str = re.sub(r"\\\n", "", input_str)
@@ -76,13 +77,12 @@ def load_config(config_path: str) -> None:
config_dict = {}
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
with open(config_path, "r", encoding="utf-8") as f:
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
elif ext == ".json":
try:
with open(config_path, "r", encoding="utf-8") as f:
input_str = f.read()
data = json.loads(input_str)
with fsspec.open(config_path, "r", encoding="utf-8") as f:
data = json.load(f)
except json.decoder.JSONDecodeError:
# backwards compat.
data = read_json_with_comments(config_path)
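
With this change, load_config itself reads through fsspec, so a config can
be loaded straight from a bucket. A small usage sketch; the bucket name is
a placeholder and gcsfs is assumed to be installed:

    from TTS.utils.io import load_config

    # Reads the JSON/YAML config through fsspec.open under the hood.
    config = load_config("gs://my-bucket/experiment/config.json")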

View File

@@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit):
num_eval_loader_workers (int):
Number of workers for evaluation time dataloader.
output_path (str):
Path for training output folder. The nonexist part of the given path is created automatically.
All training outputs are saved there.
Path for training output folder, either a local file path or other
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. The nonexistent part of the given path is created
automatically. All training artefacts are saved there.
"""
model: str = None
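
Because output_path is handed to fsspec and tensorboardX as-is, pointing a
training run at a bucket becomes a one-line change. For example (the bucket
name is a placeholder and the matching protocol dependency must be installed):

    from TTS.tts.configs import Tacotron2Config

    config = Tacotron2Config(output_path="s3://my-bucket/train_outputs")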

View File

@@ -2,6 +2,8 @@ import numpy as np
import torch
from torch import nn
from TTS.utils.io import load_fsspec
class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size):
@@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module):
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()

View File

@@ -2,6 +2,8 @@ import numpy as np
import torch
import torch.nn as nn
from TTS.utils.io import load_fsspec
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
@@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module):
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()

View File

@@ -6,11 +6,11 @@ import re
from multiprocessing import Manager
import numpy as np
import torch
from scipy import signal
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class Storage(object):
@@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
@@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
save_fsspec(state, bestmodel_path)
return best_loss

View File

@@ -1,7 +1,7 @@
import datetime
import os
import torch
from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
@@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
@@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s
bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
save_fsspec(state, bestmodel_path)
return best_loss

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import glob
import importlib
import logging
import os
@@ -12,7 +11,9 @@ import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
import fsspec
import torch
from coqpit import Coqpit
from torch import nn
@@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
KeepAverage,
count_parameters,
create_experiment_folder,
get_experiment_folder_path,
get_git_branch,
remove_experiment_folder,
set_init_dict,
to_cuda,
)
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -173,7 +174,6 @@ class Trainer:
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None
self.keep_avg_eval = None
@@ -309,7 +309,7 @@
return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
checkpoint = load_fsspec(restore_path)
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
@@ -776,7 +776,7 @@
"""🏃 train -> evaluate -> test for the number of epochs."""
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
self.total_steps_done = self.restore_step
@@ -834,9 +834,16 @@
@staticmethod
def _setup_logger_config(log_file: str) -> None:
logging.basicConfig(
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
handlers = [logging.StreamHandler()]
# Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
@staticmethod
def _is_apex_available() -> bool:
@@ -926,22 +933,27 @@ def init_arguments():
return parser
def get_last_checkpoint(path):
def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`.
Args:
path (list): Path to files to be compared.
path: Path to files to be compared.
Raises:
ValueError: If no checkpoint or best_model files are found.
Returns:
last_checkpoint (str): Last checkpoint filename.
Path to the last checkpoint.
Path to the best checkpoint.
"""
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {}
last_model_nums = {}
for key in ["checkpoint", "best_model"]:
@@ -963,7 +975,7 @@ def get_last_checkpoint(path):
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = torch.load(last_model)["step"]
last_model_num = load_fsspec(last_model)["step"]
if last_model is not None:
last_models[key] = last_model
@@ -1030,12 +1042,11 @@ def process_args(args, config=None):
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path
if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name)
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training
tb_logger = None
if args.rank == 0:
os.makedirs(audio_path, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
@@ -1047,8 +1058,6 @@
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
os.chmod(audio_path, 0o775)
os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
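
One subtlety in get_last_checkpoint above: fsspec filesystems return glob
matches without the URL scheme, so it has to be re-attached before the paths
are used. A small sketch of that behavior, assuming gcsfs and a placeholder
bucket:

    import fsspec
    from urllib.parse import urlparse

    path = "gs://my-bucket/run"
    fs = fsspec.get_mapper(path).fs
    # fs.glob yields e.g. "my-bucket/run/checkpoint_10.pth.tar",
    # i.e. without the "gs://" prefix.
    file_names = fs.glob(path + "/*.pth.tar")
    scheme = urlparse(path).scheme
    if scheme:
        file_names = [scheme + "://" + name for name in file_names]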

View File

@@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass
@@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if "r" in state:
self.decoder.set_r(state["r"])

View File

@@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
class GlowTTS(BaseTTS):
@@ -382,7 +383,7 @@ class GlowTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@dataclass
@@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -2,6 +2,7 @@ import datetime
import importlib
import pickle
import fsspec
import numpy as np
import tensorflow as tf
@@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@@ -1,6 +1,7 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
@@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf
@@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
with open(output_path, "wb") as f:
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model

View File

@@ -3,6 +3,7 @@ import os
import random
from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
@@ -84,12 +85,12 @@ class SpeakerManager:
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
@@ -294,9 +295,10 @@ def _set_file_path(path):
Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "speakers.json")
path_continue = os.path.join(path, "speakers.json")
if os.path.exists(path_restore):
fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore
if os.path.exists(path_continue):
if fs.exists(path_continue):
return path_continue
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
@@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
json_file = out_path
else:
json_file = _set_file_path(out_path)
with open(json_file) as f:
with fsspec.open(json_file, "r") as f:
return json.load(f)
@@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
if out_path is not None:
speakers_json_path = _set_file_path(out_path)
with open(speakers_json_path, "w") as f:
with fsspec.open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)

View File

@@ -1,15 +1,14 @@
# -*- coding: utf-8 -*-
import datetime
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict
import fsspec
import torch
@@ -58,23 +57,22 @@ def get_commit_hash():
return commit
def create_experiment_folder(root_path, model_name):
"""Create a folder with the current date and time"""
def get_experiment_folder_path(root_path, model_name):
"""Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files:
if os.path.exists(experiment_path):
shutil.rmtree(experiment_path, ignore_errors=True)
if fs.exists(experiment_path):
fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))

View File

@@ -1,9 +1,11 @@
import datetime
import glob
import json
import os
import pickle as pickle_tts
from shutil import copyfile
import shutil
from typing import Any
import fsspec
import torch
from coqpit import Coqpit
@@ -24,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self
def copy_model_files(config, out_path, new_fields):
def copy_model_files(config: Coqpit, out_path, new_fields):
"""Copy config.json and other model files to training folder and add
new fields.
@@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
config.update(new_fields, allow_new=True)
config.save_json(copy_config_path)
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path):
copyfile(
config.audio.stats_path,
copy_stats_path,
)
filesystem = fsspec.get_mapper(copy_stats_path).fs
if not filesystem.exists(copy_stats_path):
with fsspec.open(config.audio.stats_path, "rb") as source_file:
with fsspec.open(copy_stats_path, "wb") as target_file:
shutil.copyfileobj(source_file, target_file)
def load_fsspec(path: str, **kwargs) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args:
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.load.
Returns:
Object stored in path.
"""
with fsspec.open(path, "rb") as f:
return torch.load(f, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try:
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"])
if use_cuda:
model.cuda()
@@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
@@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
torch.save(state, output_path)
save_fsspec(state, output_path)
def save_checkpoint(
@@ -147,18 +178,16 @@ def save_best_model(
model_loss=current_loss,
**kwargs,
)
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
for model_name in model_names:
if os.path.basename(model_name) == best_model_name:
continue
os.remove(model_name)
# create symlink to best model for convinience
link_name = "best_model.pth.tar"
link_path = os.path.join(out_path, link_name)
if os.path.islink(link_path) or os.path.isfile(link_path):
os.remove(link_path)
os.symlink(best_model_name, os.path.join(out_path, link_name))
if os.path.basename(model_name) != best_model_name:
fs.rm(model_name)
# create a shortcut which always points to the currently best model
shortcut_name = "best_model.pth.tar"
shortcut_path = os.path.join(out_path, shortcut_name)
fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss
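
Note that the old symlink-based best_model.pth.tar pointer is gone: object
stores such as S3 and GCS have no symlinks, so the shortcut is now a plain
copy of the best checkpoint. A sketch of the equivalent operation; the
helper name is hypothetical:

    import os
    import fsspec

    def update_best_model_shortcut(checkpoint_path: str) -> None:
        # Copy instead of symlink so the pointer also works on object stores.
        fs = fsspec.get_mapper(checkpoint_path).fs
        shortcut = os.path.join(os.path.dirname(checkpoint_path), "best_model.pth.tar")
        fs.copy(checkpoint_path, shortcut)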

View File

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@@ -222,7 +223,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. Defaults to False.
"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)

View File

@@ -5,6 +5,8 @@ import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
@@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -2,6 +2,7 @@ import torch
from torch import nn
from torch.nn.utils import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
@@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -3,6 +3,7 @@ import math
import numpy as np
import torch
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
@@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
@@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()

View File

@@ -1,6 +1,7 @@
import datetime
import pickle
import fsspec
import tensorflow as tf
@@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
pickle.dump(state, open(output_path, "wb"))
with fsspec.open(output_path, "wb") as f:
pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
"""Load TF Vocoder model"""
checkpoint = pickle.load(open(checkpoint_path, "rb"))
with fsspec.open(checkpoint_path, "rb") as f:
checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:

View File

@@ -1,3 +1,4 @@
import fsspec
import tensorflow as tf
@@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
with open(output_path, "wb") as f:
with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model

View File

@@ -24,3 +24,4 @@ mecab-python3==1.0.3
unidic-lite==1.0.8
# gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
fsspec>=2021.04.0

View File

@@ -0,0 +1,55 @@
import glob
import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = Tacotron2Config(
r=5,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
test_sentences=[
"Be a voice, not an echo.",
],
print_eval=True,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
f"--coqpit.output_path file://{output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# restore the model and continue training for one more epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
)
run_cli(command_train)
shutil.rmtree(continue_path)
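
Using file:// URLs here forces the training entry point through the fsspec
code path while all artifacts still land in the local test output directory,
so the test exercises the new path handling without needing cloud credentials.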