Fixup `utils` for the trainer

This commit is contained in:
Eren Gölge 2021-06-18 14:26:11 +02:00
parent b4bb567e04
commit 8182f5168f
4 changed files with 10 additions and 10 deletions

View File

@ -16,6 +16,7 @@ import torch
def to_cuda(x: torch.Tensor) -> torch.Tensor: def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None: if x is None:
return None return None
if torch.is_tensor(x):
x = x.contiguous() x = x.contiguous()
if torch.cuda.is_available(): if torch.cuda.is_available():
x = x.cuda(non_blocking=True) x = x.cuda(non_blocking=True)
@ -57,12 +58,9 @@ def get_commit_hash():
return commit return commit
def create_experiment_folder(root_path, model_name, debug): def create_experiment_folder(root_path, model_name):
"""Create a folder with the current date and time""" """Create a folder with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
if debug:
commit_hash = "debug"
else:
commit_hash = get_commit_hash() commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)

View File

@ -34,6 +34,8 @@ class TensorboardLogger(object):
def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
for key, value in audios.items(): for key, value in audios.items():
if value.dtype == "float16":
value = value.astype("float32")
try: try:
self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
except RuntimeError: except RuntimeError:

View File

@ -137,7 +137,7 @@ class ModelManager(object):
# set scale stats path in config.json # set scale stats path in config.json
config_path = output_config_path config_path = output_config_path
config = load_config(config_path) config = load_config(config_path)
config.external_speaker_embedding_file = output_speakers_path config.d_vector_file = output_speakers_path
config.save_json(config_path) config.save_json(config_path)
return output_model_path, output_config_path, model_item return output_model_path, output_config_path, model_item

View File

@ -1,4 +1,4 @@
# from https://github.com/LiyuanLucasLiu/RAdam # modified from https://github.com/LiyuanLucasLiu/RAdam
import math import math