mirror of https://github.com/coqui-ai/TTS.git
Fixup `utils` for the trainer
This commit is contained in:
parent
b4bb567e04
commit
8182f5168f
|
@ -16,6 +16,7 @@ import torch
|
||||||
def to_cuda(x: torch.Tensor) -> torch.Tensor:
|
def to_cuda(x: torch.Tensor) -> torch.Tensor:
|
||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
|
if torch.is_tensor(x):
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
x = x.cuda(non_blocking=True)
|
x = x.cuda(non_blocking=True)
|
||||||
|
@ -57,12 +58,9 @@ def get_commit_hash():
|
||||||
return commit
|
return commit
|
||||||
|
|
||||||
|
|
||||||
def create_experiment_folder(root_path, model_name, debug):
|
def create_experiment_folder(root_path, model_name):
|
||||||
"""Create a folder with the current date and time"""
|
"""Create a folder with the current date and time"""
|
||||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||||
if debug:
|
|
||||||
commit_hash = "debug"
|
|
||||||
else:
|
|
||||||
commit_hash = get_commit_hash()
|
commit_hash = get_commit_hash()
|
||||||
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
||||||
os.makedirs(output_folder, exist_ok=True)
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
|
|
@ -34,6 +34,8 @@ class TensorboardLogger(object):
|
||||||
|
|
||||||
def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
|
def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
|
||||||
for key, value in audios.items():
|
for key, value in audios.items():
|
||||||
|
if value.dtype == "float16":
|
||||||
|
value = value.astype("float32")
|
||||||
try:
|
try:
|
||||||
self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
|
self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
|
|
@ -137,7 +137,7 @@ class ModelManager(object):
|
||||||
# set scale stats path in config.json
|
# set scale stats path in config.json
|
||||||
config_path = output_config_path
|
config_path = output_config_path
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
config.external_speaker_embedding_file = output_speakers_path
|
config.d_vector_file = output_speakers_path
|
||||||
config.save_json(config_path)
|
config.save_json(config_path)
|
||||||
return output_model_path, output_config_path, model_item
|
return output_model_path, output_config_path, model_item
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# from https://github.com/LiyuanLucasLiu/RAdam
|
# modified from https://github.com/LiyuanLucasLiu/RAdam
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue