mirror of https://github.com/coqui-ai/TTS.git
revert logging.info to print statements for trainer
parent fd6afe5ae5
commit c7ff175592
@@ -150,7 +150,7 @@ class TrainerTTS:
         # count model size
         num_params = count_parameters(self.model)
-        logging.info("\n > Model has {} parameters".format(num_params))
+        print("\n > Model has {} parameters".format(num_params))
 
     @staticmethod
     def get_model(num_chars: int, num_speakers: int, config: Coqpit,
@@ -186,7 +186,6 @@ class TrainerTTS:
                             out_path: str = "",
                             data_train: List = []) -> SpeakerManager:
         speaker_manager = SpeakerManager()
-        if config.use_speaker_embedding:
         if restore_path:
             speakers_file = os.path.join(os.path.dirname(restore_path),
                                          "speaker.json")
@@ -196,16 +195,6 @@ class TrainerTTS:
             )
             speakers_file = config.external_speaker_embedding_file
-            if config.use_external_speaker_embedding_file:
-                speaker_manager.load_x_vectors_file(speakers_file)
-            else:
-                speaker_manager.load_ids_file(speakers_file)
-        elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
-            speaker_manager.load_x_vectors_file(
-                config.external_speaker_embedding_file)
-        else:
-            speaker_manager.parse_speakers_from_items(data_train)
-            file_path = os.path.join(out_path, "speakers.json")
         speaker_manager.save_ids_file(file_path)
         return speaker_manager
 
@@ -238,15 +227,15 @@ class TrainerTTS:
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
         checkpoint = torch.load(restore_path)
         try:
-            logging.info(" > Restoring Model...")
+            print(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
-            logging.info(" > Restoring Optimizer...")
+            print(" > Restoring Optimizer...")
             optimizer.load_state_dict(checkpoint["optimizer"])
             if "scaler" in checkpoint and config.mixed_precision:
-                logging.info(" > Restoring AMP Scaler...")
+                print(" > Restoring AMP Scaler...")
                 scaler.load_state_dict(checkpoint["scaler"])
         except (KeyError, RuntimeError):
-            logging.info(" > Partial model initialization...")
+            print(" > Partial model initialization...")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint["model"], config)
             model.load_state_dict(model_dict)
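Not part of the commit, but a minimal sketch of the behavioural difference the revert touches: with Python's default logging configuration the root logger sits at WARNING, so logging.info output is suppressed, whereas print always reaches stdout. The parameter count below is a made-up placeholder.

import logging

# With no explicit configuration the root logger's effective level is WARNING,
# so this INFO-level record is silently dropped.
logging.info(" > Model has 1000 parameters")   # nothing is shown

# print() writes to stdout unconditionally, so the trainer banner always appears.
print(" > Model has 1000 parameters")

# The logging call only becomes visible after the root logger is configured.
# force=True (Python 3.8+) replaces the handler installed implicitly above.
logging.basicConfig(level=logging.INFO, force=True)
logging.info(" > Model has 1000 parameters")   # now emitted (to stderr)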