using a unified console_logger for tts and vocoder modules

This commit is contained in:
erogol 2020-07-28 14:03:46 +02:00
parent 84158c5e47
commit ade5fc2675
4 changed files with 11 additions and 108 deletions

View File

@ -11,11 +11,9 @@ import traceback
import numpy as np
import torch
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.console_logger import ConsoleLogger
from TTS.tts.utils.distribute import (DistributedSampler,
apply_gradient_allreduce,
init_distributed, reduce_tensor)
@ -28,6 +26,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)

View File

@ -4,13 +4,12 @@ import os
import sys
import time
import traceback
from inspect import signature
import torch
from torch.utils.data import DataLoader
from inspect import signature
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
@ -23,12 +22,10 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
# from distribute import (DistributedSampler, apply_gradient_allreduce,
# init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
from TTS.vocoder.utils.console_logger import ConsoleLogger
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
setup_discriminator,
setup_generator)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
use_cuda, num_gpus = setup_torch_training_env(True, True)
@ -238,10 +235,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
# print training stats
if global_step % c.print_step == 0:
log_dict = {
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"current_lr_G": current_lr_G,
"current_lr_D": current_lr_D
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
step_time, loader_time, current_lr_G,
current_lr_D, loss_dict,
keep_avg.avg_values)
log_dict, loss_dict, keep_avg.avg_values)
# plot step stats
if global_step % 10 == 0:

View File

@ -1,97 +0,0 @@
import datetime
from TTS.utils.io import AttrDict
# ANSI escape sequences for colorized terminal output; AttrDict exposes
# them with attribute-style access (e.g. ``tcolors.BOLD``).
tcolors = AttrDict(dict(
    OKBLUE='\033[94m',
    HEADER='\033[95m',
    OKGREEN='\033[92m',
    WARNING='\033[93m',
    FAIL='\033[91m',
    ENDC='\033[0m',
    BOLD='\033[1m',
    UNDERLINE='\033[4m',
))
class ConsoleLogger():
    """Colorized console logger for vocoder training progress.

    Prints formatted train/eval step reports and epoch summaries.
    Previous eval losses are remembered so the epoch summary can show
    the per-loss delta (green when a loss decreased, red when it
    increased).
    """
    # TODO: merge this with TTS ConsoleLogger

    def __init__(self):
        # use these to compare values between iterations
        self.old_train_loss_dict = None
        self.old_epoch_loss_dict = None
        self.old_eval_loss_dict = None

    # pylint: disable=no-self-use
    def get_time(self):
        """Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'."""
        now = datetime.datetime.now()
        return now.strftime("%Y-%m-%d %H:%M:%S")

    @staticmethod
    def _format_losses(loss_dict, avg_loss_dict, indent):
        """Format each loss as an indented 'name: value (avg)' line.

        The running average is appended in parentheses only when
        ``avg_loss_dict`` carries a matching ``avg_<name>`` entry.
        """
        log_text = ""
        for key, value in loss_dict.items():
            # print the avg value if given
            if f'avg_{key}' in avg_loss_dict:
                log_text += "{}{}: {:.5f} ({:.5f})\n".format(
                    indent, key, value, avg_loss_dict[f'avg_{key}'])
            else:
                log_text += "{}{}: {:.5f} \n".format(indent, key, value)
        return log_text

    def print_epoch_start(self, epoch, max_epoch):
        """Announce the start of epoch ``epoch`` out of ``max_epoch``."""
        print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
                                               epoch, max_epoch, tcolors.ENDC),
              flush=True)

    def print_train_start(self):
        """Announce the start of a training phase with a timestamp."""
        print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

    def print_train_step(self, batch_steps, step, global_step, log_dict,
                         step_time, loader_time, lrG, lrD,
                         loss_dict, avg_loss_dict):
        """Print a full report for one training step.

        ``log_dict`` is accepted for caller compatibility but currently
        unused; timing and learning rates are printed from the explicit
        ``step_time``/``loader_time``/``lrG``/``lrD`` arguments instead.
        """
        indent = " | > "
        print()
        log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
            tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
        log_text += self._format_losses(loss_dict, avg_loss_dict, indent)
        log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lrG: {lrG}\n{indent}lrD: {lrD}"
        print(log_text, flush=True)

    # pylint: disable=unused-argument
    def print_train_epoch_end(self, global_step, epoch, epoch_time,
                              print_dict):
        """Print the end-of-epoch training summary (``print_dict`` metrics)."""
        indent = " | > "
        # NOTE: fixed typo in the original message ("PERFORMACE")
        log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMANCE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
        for key, value in print_dict.items():
            log_text += "{}{}: {:.5f}\n".format(indent, key, value)
        print(log_text, flush=True)

    def print_eval_start(self):
        """Announce the start of an evaluation phase."""
        print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")

    def print_eval_step(self, step, loss_dict, avg_loss_dict):
        """Print the losses (and running averages) for one eval step."""
        indent = " | > "
        print()
        log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
        log_text += self._format_losses(loss_dict, avg_loss_dict, indent)
        print(log_text, flush=True)

    def print_epoch_end(self, epoch, avg_loss_dict):
        """Print the eval summary with colored deltas vs. the previous epoch.

        A decreased loss prints green with no sign; an increased loss
        prints red with a leading '+'. Stores ``avg_loss_dict`` so the
        next call can diff against it.
        """
        indent = " | > "
        log_text = " {}--> EVAL PERFORMANCE{}\n".format(
            tcolors.BOLD, tcolors.ENDC)
        for key, value in avg_loss_dict.items():
            # print the avg value if given
            color = ''
            sign = '+'
            diff = 0
            if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
                diff = value - self.old_eval_loss_dict[key]
                if diff < 0:
                    color = tcolors.OKGREEN
                    sign = ''
                elif diff > 0:
                    color = tcolors.FAIL
                    sign = '+'
            log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
        self.old_eval_loss_dict = avg_loss_dict
        print(log_text, flush=True)