mirror of https://github.com/coqui-ai/TTS.git
using a unified console_logger for tts and vocoder modules
This commit is contained in:
parent
84158c5e47
commit
ade5fc2675
|
@ -11,11 +11,9 @@ import traceback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||||
from TTS.tts.layers.losses import TacotronLoss
|
from TTS.tts.layers.losses import TacotronLoss
|
||||||
from TTS.tts.utils.console_logger import ConsoleLogger
|
|
||||||
from TTS.tts.utils.distribute import (DistributedSampler,
|
from TTS.tts.utils.distribute import (DistributedSampler,
|
||||||
apply_gradient_allreduce,
|
apply_gradient_allreduce,
|
||||||
init_distributed, reduce_tensor)
|
init_distributed, reduce_tensor)
|
||||||
|
@ -28,6 +26,7 @@ from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||||
create_experiment_folder, get_git_branch,
|
create_experiment_folder, get_git_branch,
|
||||||
remove_experiment_folder, set_init_dict)
|
remove_experiment_folder, set_init_dict)
|
||||||
|
|
|
@ -4,13 +4,12 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from inspect import signature
|
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||||
create_experiment_folder, get_git_branch,
|
create_experiment_folder, get_git_branch,
|
||||||
remove_experiment_folder, set_init_dict)
|
remove_experiment_folder, set_init_dict)
|
||||||
|
@ -23,12 +22,10 @@ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||||
# init_distributed, reduce_tensor)
|
# init_distributed, reduce_tensor)
|
||||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||||
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
|
|
||||||
from TTS.vocoder.utils.console_logger import ConsoleLogger
|
|
||||||
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
|
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
|
||||||
setup_discriminator,
|
setup_discriminator,
|
||||||
setup_generator)
|
setup_generator)
|
||||||
|
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
||||||
|
|
||||||
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||||
|
|
||||||
|
@ -238,10 +235,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||||
|
|
||||||
# print training stats
|
# print training stats
|
||||||
if global_step % c.print_step == 0:
|
if global_step % c.print_step == 0:
|
||||||
|
log_dict = {
|
||||||
|
'step_time': [step_time, 2],
|
||||||
|
'loader_time': [loader_time, 4],
|
||||||
|
"current_lr_G": current_lr_G,
|
||||||
|
"current_lr_D": current_lr_D
|
||||||
|
}
|
||||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||||
step_time, loader_time, current_lr_G,
|
log_dict, loss_dict, keep_avg.avg_values)
|
||||||
current_lr_D, loss_dict,
|
|
||||||
keep_avg.avg_values)
|
|
||||||
|
|
||||||
# plot step stats
|
# plot step stats
|
||||||
if global_step % 10 == 0:
|
if global_step % 10 == 0:
|
||||||
|
|
|
@ -1,97 +0,0 @@
|
||||||
import datetime
|
|
||||||
from TTS.utils.io import AttrDict
|
|
||||||
|
|
||||||
|
|
||||||
tcolors = AttrDict({
|
|
||||||
'OKBLUE': '\033[94m',
|
|
||||||
'HEADER': '\033[95m',
|
|
||||||
'OKGREEN': '\033[92m',
|
|
||||||
'WARNING': '\033[93m',
|
|
||||||
'FAIL': '\033[91m',
|
|
||||||
'ENDC': '\033[0m',
|
|
||||||
'BOLD': '\033[1m',
|
|
||||||
'UNDERLINE': '\033[4m'
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
class ConsoleLogger():
|
|
||||||
# TODO: merge this with TTS ConsoleLogger
|
|
||||||
def __init__(self):
|
|
||||||
# use these to compare values between iterations
|
|
||||||
self.old_train_loss_dict = None
|
|
||||||
self.old_epoch_loss_dict = None
|
|
||||||
self.old_eval_loss_dict = None
|
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def get_time(self):
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
return now.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
def print_epoch_start(self, epoch, max_epoch):
|
|
||||||
print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
|
|
||||||
epoch, max_epoch, tcolors.ENDC),
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
def print_train_start(self):
|
|
||||||
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
|
||||||
|
|
||||||
def print_train_step(self, batch_steps, step, global_step, log_dict,
|
|
||||||
step_time, loader_time, lrG, lrD,
|
|
||||||
loss_dict, avg_loss_dict):
|
|
||||||
indent = " | > "
|
|
||||||
print()
|
|
||||||
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
|
|
||||||
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
|
|
||||||
for key, value in loss_dict.items():
|
|
||||||
# print the avg value if given
|
|
||||||
if f'avg_{key}' in avg_loss_dict.keys():
|
|
||||||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
|
||||||
else:
|
|
||||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
|
||||||
log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lrG: {lrG}\n{indent}lrD: {lrD}"
|
|
||||||
print(log_text, flush=True)
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def print_train_epoch_end(self, global_step, epoch, epoch_time,
|
|
||||||
print_dict):
|
|
||||||
indent = " | > "
|
|
||||||
log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
|
|
||||||
for key, value in print_dict.items():
|
|
||||||
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
|
|
||||||
print(log_text, flush=True)
|
|
||||||
|
|
||||||
def print_eval_start(self):
|
|
||||||
print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")
|
|
||||||
|
|
||||||
def print_eval_step(self, step, loss_dict, avg_loss_dict):
|
|
||||||
indent = " | > "
|
|
||||||
print()
|
|
||||||
log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
|
|
||||||
for key, value in loss_dict.items():
|
|
||||||
# print the avg value if given
|
|
||||||
if f'avg_{key}' in avg_loss_dict.keys():
|
|
||||||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
|
||||||
else:
|
|
||||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
|
||||||
print(log_text, flush=True)
|
|
||||||
|
|
||||||
def print_epoch_end(self, epoch, avg_loss_dict):
|
|
||||||
indent = " | > "
|
|
||||||
log_text = " {}--> EVAL PERFORMANCE{}\n".format(
|
|
||||||
tcolors.BOLD, tcolors.ENDC)
|
|
||||||
for key, value in avg_loss_dict.items():
|
|
||||||
# print the avg value if given
|
|
||||||
color = ''
|
|
||||||
sign = '+'
|
|
||||||
diff = 0
|
|
||||||
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
|
|
||||||
diff = value - self.old_eval_loss_dict[key]
|
|
||||||
if diff < 0:
|
|
||||||
color = tcolors.OKGREEN
|
|
||||||
sign = ''
|
|
||||||
elif diff > 0:
|
|
||||||
color = tcolors.FAIL
|
|
||||||
sign = '+'
|
|
||||||
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
|
|
||||||
self.old_eval_loss_dict = avg_loss_dict
|
|
||||||
print(log_text, flush=True)
|
|
Loading…
Reference in New Issue