From e53616078af4074b67f6f2d8e5182d43d3679541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 18 Jun 2021 14:26:11 +0200 Subject: [PATCH] Fixup `utils` for the trainer --- TTS/utils/generic_utils.py | 14 ++++++-------- TTS/utils/logging/tensorboard_logger.py | 2 ++ TTS/utils/manage.py | 2 +- TTS/utils/radam.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 67cd0bf5..e7c57529 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -16,9 +16,10 @@ import torch def to_cuda(x: torch.Tensor) -> torch.Tensor: if x is None: return None - x = x.contiguous() - if torch.cuda.is_available(): - x = x.cuda(non_blocking=True) + if torch.is_tensor(x): + x = x.contiguous() + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) return x @@ -57,13 +58,10 @@ def get_commit_hash(): return commit -def create_experiment_folder(root_path, model_name, debug): +def create_experiment_folder(root_path, model_name): """Create a folder with the current date and time""" date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") - if debug: - commit_hash = "debug" - else: - commit_hash = get_commit_hash() + commit_hash = get_commit_hash() output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) diff --git a/TTS/utils/logging/tensorboard_logger.py b/TTS/utils/logging/tensorboard_logger.py index 657deb5b..3d7ea1e6 100644 --- a/TTS/utils/logging/tensorboard_logger.py +++ b/TTS/utils/logging/tensorboard_logger.py @@ -34,6 +34,8 @@ class TensorboardLogger(object): def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): for key, value in audios.items(): + if value.dtype == "float16": + value = value.astype("float32") try: self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) except RuntimeError: diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index f5165079..93497517 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -137,7 +137,7 @@ class ModelManager(object): # set scale stats path in config.json config_path = output_config_path config = load_config(config_path) - config.external_speaker_embedding_file = output_speakers_path + config.d_vector_file = output_speakers_path config.save_json(config_path) return output_model_path, output_config_path, model_item diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py index b6c86fed..73426e64 100644 --- a/TTS/utils/radam.py +++ b/TTS/utils/radam.py @@ -1,4 +1,4 @@ -# from https://github.com/LiyuanLucasLiu/RAdam +# modified from https://github.com/LiyuanLucasLiu/RAdam import math