update trainer.py for better logging handling, restoring models and

rename init_ functions with get_
This commit is contained in:
Eren Gölge 2021-05-27 10:24:26 +02:00
parent b8a4af4010
commit e298b8e364
2 changed files with 15 additions and 13 deletions

View File

@ -10,7 +10,11 @@ def main():
# try: # try:
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training( args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
sys.argv) sys.argv)
trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH) trainer = TrainerTTS(args,
config,
c_logger,
tb_logger,
output_path=OUT_PATH)
trainer.fit() trainer.fit()
# except KeyboardInterrupt: # except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH) # remove_experiment_folder(OUT_PATH)

View File

@ -1,12 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import importlib
import logging
import os import os
import sys import sys
import time import time
import traceback import traceback
from logging import StreamHandler
from random import randrange from random import randrange
import logging
import importlib
import numpy as np import numpy as np
import torch import torch
@ -16,19 +17,19 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets import load_meta_data, TTSDataset from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import init_training from TTS.utils.arguments import init_training
from TTS.tts.utils.visual import plot_spectrogram, plot_alignment
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module from TTS.utils.generic_utils import KeepAverage, count_parameters, find_module, remove_experiment_folder, set_init_dict
from TTS.utils.training import setup_torch_training_env, check_update from TTS.utils.training import check_update, setup_torch_training_env
@dataclass @dataclass
@ -140,9 +141,8 @@ class TrainerTTS:
self.config, args.restore_path, self.model, self.optimizer, self.config, args.restore_path, self.model, self.optimizer,
self.scaler) self.scaler)
if self.use_cuda: # setup scheduler
self.model.cuda() self.scheduler = self.get_scheduler(self.config, self.optimizer)
self.criterion.cuda()
# DISTRUBUTED # DISTRUBUTED
if self.num_gpus > 1: if self.num_gpus > 1:
@ -150,8 +150,7 @@ class TrainerTTS:
# count model size # count model size
num_params = count_parameters(self.model) num_params = count_parameters(self.model)
logging.info("\n > Model has {} parameters".format(num_params), logging.info("\n > Model has {} parameters".format(num_params))
flush=True)
@staticmethod @staticmethod
def get_model(num_chars: int, num_speakers: int, config: Coqpit, def get_model(num_chars: int, num_speakers: int, config: Coqpit,
@ -241,7 +240,6 @@ class TrainerTTS:
try: try:
logging.info(" > Restoring Model...") logging.info(" > Restoring Model...")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
# optimizer restore
logging.info(" > Restoring Optimizer...") logging.info(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint["optimizer"]) optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and config.mixed_precision: if "scaler" in checkpoint and config.mixed_precision: