mirror of https://github.com/coqui-ai/TTS.git

update trainer.py for better logging handling and model restoring; rename init_* functions to get_*

parent b8a4af4010
commit e298b8e364
@@ -10,7 +10,11 @@ def main():
     # try:
     args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
         sys.argv)
-    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
+    trainer = TrainerTTS(args,
+                         config,
+                         c_logger,
+                         tb_logger,
+                         output_path=OUT_PATH)
     trainer.fit()
     # except KeyboardInterrupt:
     # remove_experiment_folder(OUT_PATH)
@@ -1,12 +1,13 @@
 # -*- coding: utf-8 -*-
 
+import importlib
+import logging
 import os
 import sys
 import time
 import traceback
+from logging import StreamHandler
 from random import randrange
-import logging
-import importlib
 
 import numpy as np
 import torch
@@ -16,19 +17,19 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from TTS.tts.datasets import load_meta_data, TTSDataset
+from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.arguments import init_training
-from TTS.tts.utils.visual import plot_spectrogram, plot_alignment
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed, reduce_tensor
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module
-from TTS.utils.training import setup_torch_training_env, check_update
+from TTS.utils.generic_utils import KeepAverage, count_parameters, find_module, remove_experiment_folder, set_init_dict
+from TTS.utils.training import check_update, setup_torch_training_env
 
 
 @dataclass
@@ -140,9 +141,8 @@ class TrainerTTS:
             self.config, args.restore_path, self.model, self.optimizer,
             self.scaler)
 
-        if self.use_cuda:
-            self.model.cuda()
-            self.criterion.cuda()
+        # setup scheduler
+        self.scheduler = self.get_scheduler(self.config, self.optimizer)
 
         # DISTRUBUTED
         if self.num_gpus > 1:
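The hunk above swaps the CUDA placement block for scheduler setup through a renamed get_scheduler helper whose body is not shown in this diff. Below is a minimal sketch of what such a helper might do; the config fields lr_scheduler and lr_scheduler_params are assumptions for illustration, not taken from the commit.

import torch

def get_scheduler(config, optimizer):
    # Hypothetical helper: pick a torch LR scheduler by name from the config.
    if getattr(config, "lr_scheduler", None) is None:
        return None  # no scheduler configured, keep a constant learning rate
    scheduler_class = getattr(torch.optim.lr_scheduler, config.lr_scheduler)  # e.g. "StepLR"
    return scheduler_class(optimizer, **config.lr_scheduler_params)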
@@ -150,8 +150,7 @@
 
         # count model size
         num_params = count_parameters(self.model)
-        logging.info("\n > Model has {} parameters".format(num_params),
-                     flush=True)
+        logging.info("\n > Model has {} parameters".format(num_params))
 
     @staticmethod
     def get_model(num_chars: int, num_speakers: int, config: Coqpit,
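This hunk and the import changes earlier carry the "better logging handling" part of the commit message: the flush=True keyword (a print() leftover that logging.info does not accept) is dropped, and a StreamHandler import is added. A minimal sketch of how such console logging might be wired up, assuming the goal is to send logging.info output to stdout; the handler, formatter, and level choices here are my own, not the repository's code.

import logging
import sys
from logging import StreamHandler

# Route INFO-level messages to stdout with a bare message format.
console = StreamHandler(sys.stdout)
console.setFormatter(logging.Formatter("%(message)s"))
root = logging.getLogger()
root.addHandler(console)
root.setLevel(logging.INFO)

logging.info(" > Model has {} parameters".format(28500000))  # example value only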
@@ -241,7 +240,6 @@ class TrainerTTS:
         try:
             logging.info(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
-            # optimizer restore
             logging.info(" > Restoring Optimizer...")
             optimizer.load_state_dict(checkpoint["optimizer"])
             if "scaler" in checkpoint and config.mixed_precision:
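The last hunk trims a stale comment from the checkpoint-restore path that the commit message calls "restoring models". For orientation, a minimal sketch of that restore flow using only the checkpoint keys visible in the diff ("model", "optimizer", "scaler"); the function name and signature are assumptions, not the class method itself.

import logging
import torch

def restore_checkpoint(restore_path, model, optimizer, scaler, mixed_precision):
    # Hypothetical illustration of the flow shown in the hunk above.
    checkpoint = torch.load(restore_path, map_location="cpu")
    logging.info(" > Restoring Model...")
    model.load_state_dict(checkpoint["model"])
    logging.info(" > Restoring Optimizer...")
    optimizer.load_state_dict(checkpoint["optimizer"])
    if "scaler" in checkpoint and mixed_precision:
        # Restore AMP GradScaler state only when mixed precision is enabled.
        scaler.load_state_dict(checkpoint["scaler"])
    return model, optimizer, scaler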