mirror of https://github.com/coqui-ai/TTS.git
using a unified console_logger for tts and vocoder modules
This commit is contained in:
parent
84158c5e47
commit
ade5fc2675
|
@ -11,11 +11,9 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.utils.console_logger import ConsoleLogger
|
||||
from TTS.tts.utils.distribute import (DistributedSampler,
|
||||
apply_gradient_allreduce,
|
||||
init_distributed, reduce_tensor)
|
||||
|
@ -28,6 +26,7 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
|
|
|
@ -4,13 +4,12 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from inspect import signature
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
|
@ -23,12 +22,10 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
|||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
# init_distributed, reduce_tensor)
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
|
||||
from TTS.vocoder.utils.console_logger import ConsoleLogger
|
||||
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
|
||||
setup_discriminator,
|
||||
setup_generator)
|
||||
|
||||
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||
|
||||
|
@ -238,10 +235,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
|
||||
# print training stats
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {
|
||||
'step_time': [step_time, 2],
|
||||
'loader_time': [loader_time, 4],
|
||||
"current_lr_G": current_lr_G,
|
||||
"current_lr_D": current_lr_D
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||
step_time, loader_time, current_lr_G,
|
||||
current_lr_D, loss_dict,
|
||||
keep_avg.avg_values)
|
||||
log_dict, loss_dict, keep_avg.avg_values)
|
||||
|
||||
# plot step stats
|
||||
if global_step % 10 == 0:
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
import datetime
|
||||
from TTS.utils.io import AttrDict
|
||||
|
||||
|
||||
tcolors = AttrDict({
|
||||
'OKBLUE': '\033[94m',
|
||||
'HEADER': '\033[95m',
|
||||
'OKGREEN': '\033[92m',
|
||||
'WARNING': '\033[93m',
|
||||
'FAIL': '\033[91m',
|
||||
'ENDC': '\033[0m',
|
||||
'BOLD': '\033[1m',
|
||||
'UNDERLINE': '\033[4m'
|
||||
})
|
||||
|
||||
|
||||
class ConsoleLogger():
|
||||
# TODO: merge this with TTS ConsoleLogger
|
||||
def __init__(self):
|
||||
# use these to compare values between iterations
|
||||
self.old_train_loss_dict = None
|
||||
self.old_epoch_loss_dict = None
|
||||
self.old_eval_loss_dict = None
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def get_time(self):
|
||||
now = datetime.datetime.now()
|
||||
return now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def print_epoch_start(self, epoch, max_epoch):
|
||||
print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
|
||||
epoch, max_epoch, tcolors.ENDC),
|
||||
flush=True)
|
||||
|
||||
def print_train_start(self):
|
||||
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
||||
|
||||
def print_train_step(self, batch_steps, step, global_step, log_dict,
|
||||
step_time, loader_time, lrG, lrD,
|
||||
loss_dict, avg_loss_dict):
|
||||
indent = " | > "
|
||||
print()
|
||||
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
|
||||
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
|
||||
for key, value in loss_dict.items():
|
||||
# print the avg value if given
|
||||
if f'avg_{key}' in avg_loss_dict.keys():
|
||||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
||||
else:
|
||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
||||
log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lrG: {lrG}\n{indent}lrD: {lrD}"
|
||||
print(log_text, flush=True)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def print_train_epoch_end(self, global_step, epoch, epoch_time,
|
||||
print_dict):
|
||||
indent = " | > "
|
||||
log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
|
||||
for key, value in print_dict.items():
|
||||
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
|
||||
print(log_text, flush=True)
|
||||
|
||||
def print_eval_start(self):
|
||||
print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")
|
||||
|
||||
def print_eval_step(self, step, loss_dict, avg_loss_dict):
|
||||
indent = " | > "
|
||||
print()
|
||||
log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
|
||||
for key, value in loss_dict.items():
|
||||
# print the avg value if given
|
||||
if f'avg_{key}' in avg_loss_dict.keys():
|
||||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
||||
else:
|
||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
||||
print(log_text, flush=True)
|
||||
|
||||
def print_epoch_end(self, epoch, avg_loss_dict):
|
||||
indent = " | > "
|
||||
log_text = " {}--> EVAL PERFORMANCE{}\n".format(
|
||||
tcolors.BOLD, tcolors.ENDC)
|
||||
for key, value in avg_loss_dict.items():
|
||||
# print the avg value if given
|
||||
color = ''
|
||||
sign = '+'
|
||||
diff = 0
|
||||
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
|
||||
diff = value - self.old_eval_loss_dict[key]
|
||||
if diff < 0:
|
||||
color = tcolors.OKGREEN
|
||||
sign = ''
|
||||
elif diff > 0:
|
||||
color = tcolors.FAIL
|
||||
sign = '+'
|
||||
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
|
||||
self.old_eval_loss_dict = avg_loss_dict
|
||||
print(log_text, flush=True)
|
Loading…
Reference in New Issue