using a unified console_logger for tts and vocoder modules

This commit is contained in:
erogol 2020-07-28 14:03:46 +02:00
parent 84158c5e47
commit ade5fc2675
4 changed files with 11 additions and 108 deletions

View File

@ -11,11 +11,9 @@ import traceback
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.console_logger import ConsoleLogger
from TTS.tts.utils.distribute import (DistributedSampler, from TTS.tts.utils.distribute import (DistributedSampler,
apply_gradient_allreduce, apply_gradient_allreduce,
init_distributed, reduce_tensor) init_distributed, reduce_tensor)
@ -28,6 +26,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch, create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)

View File

@ -4,13 +4,12 @@ import os
import sys import sys
import time import time
import traceback import traceback
from inspect import signature
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from inspect import signature
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch, create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
@ -23,12 +22,10 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
# from distribute import (DistributedSampler, apply_gradient_allreduce, # from distribute import (DistributedSampler, apply_gradient_allreduce,
# init_distributed, reduce_tensor) # init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
from TTS.vocoder.utils.console_logger import ConsoleLogger
from TTS.vocoder.utils.generic_utils import (check_config, plot_results, from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
setup_discriminator, setup_discriminator,
setup_generator) setup_generator)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
use_cuda, num_gpus = setup_torch_training_env(True, True) use_cuda, num_gpus = setup_torch_training_env(True, True)
@ -238,10 +235,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
# print training stats # print training stats
if global_step % c.print_step == 0: if global_step % c.print_step == 0:
log_dict = {
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"current_lr_G": current_lr_G,
"current_lr_D": current_lr_D
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step, c_logger.print_train_step(batch_n_iter, num_iter, global_step,
step_time, loader_time, current_lr_G, log_dict, loss_dict, keep_avg.avg_values)
current_lr_D, loss_dict,
keep_avg.avg_values)
# plot step stats # plot step stats
if global_step % 10 == 0: if global_step % 10 == 0:

View File

@ -1,97 +0,0 @@
import datetime
from TTS.utils.io import AttrDict
tcolors = AttrDict({
'OKBLUE': '\033[94m',
'HEADER': '\033[95m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
'ENDC': '\033[0m',
'BOLD': '\033[1m',
'UNDERLINE': '\033[4m'
})
class ConsoleLogger():
# TODO: merge this with TTS ConsoleLogger
def __init__(self):
# use these to compare values between iterations
self.old_train_loss_dict = None
self.old_epoch_loss_dict = None
self.old_eval_loss_dict = None
# pylint: disable=no-self-use
def get_time(self):
now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch):
print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
epoch, max_epoch, tcolors.ENDC),
flush=True)
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step, log_dict,
step_time, loader_time, lrG, lrD,
loss_dict, avg_loss_dict):
indent = " | > "
print()
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
for key, value in loss_dict.items():
# print the avg value if given
if f'avg_{key}' in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lrG: {lrG}\n{indent}lrD: {lrD}"
print(log_text, flush=True)
# pylint: disable=unused-argument
def print_train_epoch_end(self, global_step, epoch, epoch_time,
print_dict):
indent = " | > "
log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
for key, value in print_dict.items():
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
print(log_text, flush=True)
def print_eval_start(self):
print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")
def print_eval_step(self, step, loss_dict, avg_loss_dict):
indent = " | > "
print()
log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
for key, value in loss_dict.items():
# print the avg value if given
if f'avg_{key}' in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
print(log_text, flush=True)
def print_epoch_end(self, epoch, avg_loss_dict):
indent = " | > "
log_text = " {}--> EVAL PERFORMANCE{}\n".format(
tcolors.BOLD, tcolors.ENDC)
for key, value in avg_loss_dict.items():
# print the avg value if given
color = ''
sign = '+'
diff = 0
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
diff = value - self.old_eval_loss_dict[key]
if diff < 0:
color = tcolors.OKGREEN
sign = ''
elif diff > 0:
color = tcolors.FAIL
sign = '+'
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
self.old_eval_loss_dict = avg_loss_dict
print(log_text, flush=True)