`TrainerAbstract` and related updates for `TrainerTTS`

Eren Gölge 2021-06-08 14:18:18 +02:00
parent 00c82c516d
commit c754a0e17d
3 changed files with 771 additions and 654 deletions
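This change splits the old monolithic `TrainerTTS` in TTS/trainer.py into an abstract base class, `TrainerAbstract`, that fixes the training-loop contract, and a concrete TTS trainer that moves to the new file TTS/tts/trainer_tts.py. A minimal sketch of the resulting pattern, for orientation only (the real classes follow in the diff below):

from abc import ABC, abstractmethod
from torch import nn

class TrainerAbstract(ABC):
    # Declares the training-loop contract; the real class in TTS/trainer.py
    # has roughly twenty such abstract hooks (get_optimizer, train_step, ...).
    @staticmethod
    @abstractmethod
    def get_model(*args, **kwargs) -> nn.Module:
        ...

    @abstractmethod
    def fit(self) -> None:
        ...

class TrainerTTS(TrainerAbstract):
    # Concrete TTS trainer; the full implementation is the new file below.
    @staticmethod
    def get_model(*args, **kwargs) -> nn.Module:
        ...

    def fit(self) -> None:
        ...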

View File

@@ -2,7 +2,7 @@ import os
import sys
import traceback
from TTS.trainer import TrainerTTS
from TTS.tts.trainer_tts import TrainerTTS
from TTS.utils.arguments import init_training
from TTS.utils.generic_utils import remove_experiment_folder
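For callers, the only change in this script is the import path; the class itself is constructed and driven the same way. A hedged sketch of the entry point, treating the wiring (the unpacking of `init_training` and the interrupt cleanup) as illustrative assumptions, since neither is shown in this hunk:

import sys
from TTS.tts.trainer_tts import TrainerTTS
from TTS.utils.arguments import init_training
from TTS.utils.generic_utils import remove_experiment_folder

def main():
    # Assumed unpacking; init_training's exact return values are not part of this diff.
    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
    try:
        trainer = TrainerTTS(args, config, c_logger=c_logger,
                             tb_logger=tb_logger, output_path=output_path)
        trainer.fit()
    except KeyboardInterrupt:
        # clean up the half-written experiment folder on manual interrupt
        remove_experiment_folder(output_path)

if __name__ == "__main__":
    main()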

View File

@@ -1,39 +1,23 @@
# -*- coding: utf-8 -*-
import importlib
import logging
import os
import time
from argparse import Namespace
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple, TypeVar
import torch
from coqpit import Coqpit
# DISTRIBUTED
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env
_DataLoader = TypeVar("_DataLoader")
@dataclass
class TrainingArgs(Coqpit):
"""Trainer arguments that are parsed externally (e.g. CLI)"""
continue_path: str = field(
default="",
metadata={
@@ -58,676 +42,100 @@ class TrainingArgs(Coqpit):
# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerTTS:
use_cuda, num_gpus = setup_torch_training_env(True, False)
def __init__(
self,
args: Union[Coqpit, Namespace],
config: Coqpit,
c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None,
model: nn.Module = None,
output_path: str = None,
) -> None:
self.args = args
self.config = config
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
if tb_logger is None:
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
else:
self.tb_logger = tb_logger
self.output_path = output_path
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None
self.keep_avg_eval = None
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# model, audio processor, datasets, loss
# init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict())
# init character processor
self.model_characters = self.get_character_processor(self.config)
# load dataset samples
self.data_train, self.data_eval = load_meta_data(self.config.datasets)
# default speaker manager
self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train)
# init TTS model
if model is not None:
self.model = model
else:
self.model = self.get_model(
len(self.model_characters),
self.speaker_manager.num_speakers,
self.config,
self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None,
)
# setup criterion
self.criterion = self.get_criterion(self.config)
# DISTRIBUTED
if self.num_gpus > 1:
init_distributed(
args.rank,
self.num_gpus,
args.group_id,
self.config.distributed_backend,
self.config.distributed_url,
)
if self.use_cuda:
self.model.cuda()
self.criterion.cuda()
# scalers for mixed precision training
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)
if self.args.restore_path:
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer, self.scaler
)
# setup scheduler
self.scheduler = self.get_scheduler(self.config, self.optimizer)
# DISTRIBUTED
if self.num_gpus > 1:
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
# count model size
num_params = count_parameters(self.model)
print("\n > Model has {} parameters".format(num_params))
class TrainerAbstract(ABC):
@staticmethod
def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module:
model = setup_model(num_chars, num_speakers, config, d_vector_dim)
return model
def _is_apex_available():
return importlib.util.find_spec("apex") is not None
@staticmethod
@abstractmethod
def get_model(*args, **kwargs) -> nn.Module:
pass
@staticmethod
@abstractmethod
def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
optimizer_name = config.optimizer
optimizer_params = config.optimizer_params
if optimizer_name.lower() == "radam":
module = importlib.import_module("TTS.utils.radam")
optimizer = getattr(module, "RAdam")
else:
optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
@staticmethod
def get_character_processor(config: Coqpit) -> List[str]:
# setup custom characters if set in config file.
# TODO: implement CharacterProcessor
if config.characters is not None:
symbols, phonemes = make_symbols(**config.characters.to_dict())
else:
from TTS.tts.utils.text.symbols import phonemes, symbols
model_characters = phonemes if config.use_phonemes else symbols
return model_characters
@staticmethod
def get_speaker_manager(
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
) -> SpeakerManager:
speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
return speaker_manager
pass
@staticmethod
@abstractmethod
def get_scheduler(
config: Coqpit, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access
lr_scheduler = config.lr_scheduler
lr_scheduler_params = config.lr_scheduler_params
if lr_scheduler is None:
return None
if lr_scheduler.lower() == "noamlr":
from TTS.utils.training import NoamLR
scheduler = NoamLR
else:
scheduler = getattr(torch.optim, lr_scheduler)
return scheduler(optimizer, **lr_scheduler_params)
pass
@staticmethod
@abstractmethod
def get_criterion(config: Coqpit) -> nn.Module:
return setup_loss(config)
pass
def restore_model(
self,
config: Coqpit,
restore_path: str,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scaler: torch.cuda.amp.GradScaler = None,
) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and config.mixed_precision:
print(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
except (KeyError, RuntimeError):
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
model.load_state_dict(model_dict)
del model_dict
@abstractmethod
def restore_model(self, *args, **kwargs) -> Tuple:
pass
for group in optimizer.param_groups:
group["lr"] = self.config.lr
print(
" > Model restored from step %d" % checkpoint["step"],
)
restore_step = checkpoint["step"]
return model, optimizer, scaler, restore_step
@abstractmethod
def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
pass
def _get_loader(
self,
r: int,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
speaker_ids: Union[Dict, List],
d_vectors: Union[Dict, List],
) -> DataLoader:
if is_eval and not self.config.run_eval:
loader = None
else:
dataset = TTSDataset(
outputs_per_step=r,
text_cleaner=self.config.text_cleaner,
compute_linear_spec=self.config.model.lower() == "tacotron",
meta_data=data_items,
ap=ap,
tp=self.config.characters,
add_blank=self.config["add_blank"],
batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
min_seq_len=self.config.min_seq_len,
max_seq_len=self.config.max_seq_len,
phoneme_cache_path=self.config.phoneme_cache_path,
use_phonemes=self.config.use_phonemes,
phoneme_language=self.config.phoneme_language,
enable_eos_bos=self.config.enable_eos_bos_chars,
use_noise_augment=not is_eval,
verbose=verbose,
speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
d_vector_mapping=d_vectors
if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
else None,
)
if self.config.use_phonemes and self.config.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(self.config.num_loader_workers)
dataset.sort_items()
sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def get_train_dataloader(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
) -> DataLoader:
return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
def get_eval_dataloder(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
) -> DataLoader:
return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
@abstractmethod
def get_eval_dataloder(self, *args, **kwargs) -> _DataLoader:
pass
@abstractmethod
def format_batch(self, batch: List) -> Dict:
# setup input batch
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
item_idx = batch[7]
d_vectors = batch[8]
speaker_ids = batch[9]
attn_mask = batch[10]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
# compute durations from attention masks
durations = None
if attn_mask is not None:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
dur[c_idxs] = counts
# smooth the durations and set any 0 duration to 1
# by cutting off from the largest duration indices.
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
# set the stop-targets view; a single stop token is predicted per iteration.
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
# dispatch batch to GPU
if self.use_cuda:
text_input = to_cuda(text_input)
text_lengths = to_cuda(text_lengths)
mel_input = to_cuda(mel_input)
mel_lengths = to_cuda(mel_lengths)
linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
stop_targets = to_cuda(stop_targets)
attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
durations = to_cuda(durations) if attn_mask is not None else None
if speaker_ids is not None:
speaker_ids = to_cuda(speaker_ids)
if d_vectors is not None:
d_vectors = to_cuda(d_vectors)
return {
"text_input": text_input,
"text_lengths": text_lengths,
"speaker_names": speaker_names,
"mel_input": mel_input,
"mel_lengths": mel_lengths,
"linear_input": linear_input,
"stop_targets": stop_targets,
"attn_mask": attn_mask,
"durations": durations,
"speaker_ids": speaker_ids,
"d_vectors": d_vectors,
"max_text_length": max_text_length,
"max_spec_length": max_spec_length,
"item_idx": item_idx,
}
pass
@abstractmethod
def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
if hasattr(self.model, "module"):
return self.model.module.train_step(batch, criterion)
return self.model.train_step(batch, criterion)
pass
@abstractmethod
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
self.on_train_step_start()
step_start_time = time.time()
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
# zero-out optimizer
self.optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self._train_step(batch, self.criterion)
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
# optimizer step
if self.config.mixed_precision:
# model optimizer step in mixed precision mode
self.scaler.scale(loss_dict["loss"]).backward()
self.scaler.unscale_(self.optimizer)
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# main model optimizer step
loss_dict["loss"].backward()
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.optimizer.step()
step_time = time.time() - step_start_time
# setup lr
if self.config.lr_scheduler:
self.scheduler.step()
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
self.keep_avg_train.update_values(update_train_values)
# print training progress
current_lr = self.optimizer.param_groups[0]["lr"]
if self.total_steps_done % self.config.print_step == 0:
log_dict = {
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
"max_text_length": [batch["max_text_length"], 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
)
if self.args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if self.total_steps_done % self.config.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
}
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
if self.total_steps_done % self.config.save_step == 0:
if self.config.checkpoint:
# save model
save_checkpoint(
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
model_loss=loss_dict["loss"],
characters=self.model_characters,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
# training visualizations
if hasattr(self.model, "module"):
figures, audios = self.model.module.train_log(self.ap, batch, outputs)
else:
figures, audios = self.model.train_log(self.ap, batch, outputs)
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
self.total_steps_done += 1
self.on_train_step_end()
return outputs, loss_dict
pass
@abstractmethod
def train_epoch(self) -> None:
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
pass
@abstractmethod
def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
if hasattr(self.model, "module"):
return self.model.module.eval_step(batch, self.criterion)
return self.model.eval_step(batch, self.criterion)
pass
@abstractmethod
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
with torch.no_grad():
step_start_time = time.time()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self._eval_step(batch)
step_time = time.time() - step_start_time
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
update_eval_values["avg_step_time"] = step_time
self.keep_avg_eval.update_values(update_eval_values)
if self.config.print_eval:
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
return outputs, loss_dict
pass
@abstractmethod
def eval_epoch(self) -> None:
self.model.eval()
self.c_logger.print_eval_start()
loader_start_time = time.time()
batch = None
for cur_step, batch in enumerate(self.eval_loader):
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step)
# Plot epoch stats and samples from the last batch.
if self.args.rank == 0:
if hasattr(self.model, "module"):
figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
else:
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
pass
def test_run(
self,
) -> None:
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_aux_inputs()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
self.model,
sen,
self.config,
self.use_cuda,
self.ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
def _get_aux_inputs(self) -> Dict:
# setup speaker_id
speaker_id = 0 if self.config.use_speaker_embedding else None
# setup d_vector
d_vector = (
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
else None
)
# setup style_mel
if self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input
else:
style_wav = None
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# initialize GST with a zero dict.
style_wav = {}
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
for i in range(self.config.gst["gst_num_style_tokens"]):
style_wav[str(i)] = 0
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
return aux_inputs
@abstractmethod
def test_run(self) -> None:
pass
@abstractmethod
def fit(self) -> None:
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
# define data loaders
self.train_loader = self.get_train_dataloader(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_ids=self.speaker_manager.speaker_ids,
d_vectors=self.speaker_manager.d_vectors,
)
self.eval_loader = (
self.get_eval_dataloder(
self.config.r,
self.ap,
self.data_eval,
verbose=True,
speaker_ids=self.speaker_manager.speaker_ids,
d_vectors=self.speaker_manager.d_vectors,
)
if self.config.run_eval
else None
)
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
self.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
)
self.save_best_model()
self.on_epoch_end()
pass
@abstractmethod
def save_best_model(self) -> None:
self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
self.best_loss,
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
self.model_characters,
keep_all_best=self.config.keep_all_best,
keep_after=self.config.keep_after,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
pass
@staticmethod
def _setup_logger_config(log_file: str) -> None:
logging.basicConfig(
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
@abstractmethod
def on_epoch_start(self) -> None:
pass
def on_epoch_start(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_epoch_start"):
self.model.on_epoch_start(self)
@abstractmethod
def on_epoch_end(self) -> None:
pass
if hasattr(self.criterion, "on_epoch_start"):
self.criterion.on_epoch_start(self)
@abstractmethod
def on_train_step_start(self) -> None:
pass
if hasattr(self.optimizer, "on_epoch_start"):
self.optimizer.on_epoch_start(self)
def on_epoch_end(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_epoch_end"):
self.model.on_epoch_end(self)
if hasattr(self.criterion, "on_epoch_end"):
self.criterion.on_epoch_end(self)
if hasattr(self.optimizer, "on_epoch_end"):
self.optimizer.on_epoch_end(self)
def on_train_step_start(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_train_step_start"):
self.model.on_train_step_start(self)
if hasattr(self.criterion, "on_train_step_start"):
self.criterion.on_train_step_start(self)
if hasattr(self.optimizer, "on_train_step_start"):
self.optimizer.on_train_step_start(self)
def on_train_step_end(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_train_step_end"):
self.model.on_train_step_end(self)
if hasattr(self.criterion, "on_train_step_end"):
self.criterion.on_train_step_end(self)
if hasattr(self.optimizer, "on_train_step_end"):
self.optimizer.on_train_step_end(self)
@abstractmethod
def on_train_step_end(self) -> None:
pass
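The on_* callbacks are where the abstraction gets its flexibility: the concrete trainer dispatches each of them to whichever of its model, criterion, or optimizer defines a hook with the same name, passing the trainer itself. A hypothetical model-side hook, to illustrate the contract (the annealed attribute is made up):

from torch import nn

class HookAwareModel(nn.Module):
    """Hypothetical model showing how the trainer callbacks are consumed."""

    def __init__(self):
        super().__init__()
        self.kl_weight = 0.0  # made-up quantity, annealed once per epoch

    def on_epoch_start(self, trainer):
        # Invoked by the trainer's on_epoch_start() via hasattr dispatch;
        # anything reachable from the trainer (config, epochs_done, ...) is
        # available here.
        self.kl_weight = min(1.0, trainer.epochs_done / 100.0)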

TTS/tts/trainer_tts.py Normal file (+709)
View File

@@ -0,0 +1,709 @@
# -*- coding: utf-8 -*-
import importlib
import logging
import os
import time
from argparse import Namespace
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
# DISTRIBUTED
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.trainer import TrainerAbstract
from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env
# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerTTS(TrainerAbstract):
use_cuda, num_gpus = setup_torch_training_env(True, False)
def __init__(
self,
args: Union[Coqpit, Namespace],
config: Coqpit,
c_logger: ConsoleLogger = None,
tb_logger: TensorboardLogger = None,
model: nn.Module = None,
output_path: str = None,
) -> None:
self.args = args
self.config = config
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
if tb_logger is None:
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
else:
self.tb_logger = tb_logger
self.output_path = output_path
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, "test_audios")
self.keep_avg_train = None
self.keep_avg_eval = None
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# model, audio processor, datasets, loss
# init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict())
# init character processor
self.model_characters = self.get_character_processor(self.config)
# load dataset samples
self.data_train, self.data_eval = load_meta_data(self.config.datasets)
# default speaker manager
self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train)
# init TTS model
if model is not None:
self.model = model
else:
self.model = self.get_model(
len(self.model_characters),
self.speaker_manager.num_speakers,
self.config,
self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None,
)
# setup criterion
self.criterion = self.get_criterion(self.config)
# DISTRIBUTED
if self.num_gpus > 1:
init_distributed(
args.rank,
self.num_gpus,
args.group_id,
self.config.distributed_backend,
self.config.distributed_url,
)
if self.use_cuda:
self.model.cuda()
self.criterion.cuda()
# scalers for mixed precision training
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)
if self.args.restore_path:
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer, self.scaler
)
# setup scheduler
self.scheduler = self.get_scheduler(self.config, self.optimizer)
# DISTRIBUTED
if self.num_gpus > 1:
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
# count model size
num_params = count_parameters(self.model)
print("\n > Model has {} parameters".format(num_params))
@staticmethod
def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module:
model = setup_model(num_chars, num_speakers, config, d_vector_dim)
return model
@staticmethod
def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
optimizer_name = config.optimizer
optimizer_params = config.optimizer_params
if optimizer_name.lower() == "radam":
module = importlib.import_module("TTS.utils.radam")
optimizer = getattr(module, "RAdam")
else:
optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
@staticmethod
def get_character_processor(config: Coqpit) -> List[str]:
# setup custom characters if set in config file.
# TODO: implement CharacterProcessor
if config.characters is not None:
symbols, phonemes = make_symbols(**config.characters.to_dict())
else:
from TTS.tts.utils.text.symbols import phonemes, symbols
model_characters = phonemes if config.use_phonemes else symbols
return model_characters
@staticmethod
def get_speaker_manager(
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
) -> SpeakerManager:
speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
return speaker_manager
@staticmethod
def get_scheduler(
config: Coqpit, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access
lr_scheduler = config.lr_scheduler
lr_scheduler_params = config.lr_scheduler_params
if lr_scheduler is None:
return None
if lr_scheduler.lower() == "noamlr":
from TTS.utils.training import NoamLR
scheduler = NoamLR
else:
scheduler = getattr(torch.optim, lr_scheduler)
return scheduler(optimizer, **lr_scheduler_params)
@staticmethod
def get_criterion(config: Coqpit) -> nn.Module:
return setup_loss(config)
def restore_model(
self,
config: Coqpit,
restore_path: str,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scaler: torch.cuda.amp.GradScaler = None,
) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and config.mixed_precision:
print(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
except (KeyError, RuntimeError):
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group["lr"] = self.config.lr
print(
" > Model restored from step %d" % checkpoint["step"],
)
restore_step = checkpoint["step"]
return model, optimizer, scaler, restore_step
def _get_loader(
self,
r: int,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
speaker_ids: Union[Dict, List],
d_vectors: Union[Dict, List],
) -> DataLoader:
if is_eval and not self.config.run_eval:
loader = None
else:
dataset = TTSDataset(
outputs_per_step=r,
text_cleaner=self.config.text_cleaner,
compute_linear_spec=self.config.model.lower() == "tacotron",
meta_data=data_items,
ap=ap,
tp=self.config.characters,
add_blank=self.config["add_blank"],
batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
min_seq_len=self.config.min_seq_len,
max_seq_len=self.config.max_seq_len,
phoneme_cache_path=self.config.phoneme_cache_path,
use_phonemes=self.config.use_phonemes,
phoneme_language=self.config.phoneme_language,
enable_eos_bos=self.config.enable_eos_bos_chars,
use_noise_augment=not is_eval,
verbose=verbose,
speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
d_vector_mapping=d_vectors
if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
else None,
)
if self.config.use_phonemes and self.config.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(self.config.num_loader_workers)
dataset.sort_items()
sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def get_train_dataloader(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
) -> DataLoader:
return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
def get_eval_dataloder(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
) -> DataLoader:
return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
def format_batch(self, batch: List) -> Dict:
# setup input batch
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
item_idx = batch[7]
d_vectors = batch[8]
speaker_ids = batch[9]
attn_mask = batch[10]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
# compute durations from attention masks
durations = None
if attn_mask is not None:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
dur[c_idxs] = counts
# smooth the durations and set any 0 duration to 1
# by cutting off from the largest duration indices.
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
# set the stop-targets view; a single stop token is predicted per iteration.
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
# dispatch batch to GPU
if self.use_cuda:
text_input = to_cuda(text_input)
text_lengths = to_cuda(text_lengths)
mel_input = to_cuda(mel_input)
mel_lengths = to_cuda(mel_lengths)
linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
stop_targets = to_cuda(stop_targets)
attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
durations = to_cuda(durations) if attn_mask is not None else None
if speaker_ids is not None:
speaker_ids = to_cuda(speaker_ids)
if d_vectors is not None:
d_vectors = to_cuda(d_vectors)
return {
"text_input": text_input,
"text_lengths": text_lengths,
"speaker_names": speaker_names,
"mel_input": mel_input,
"mel_lengths": mel_lengths,
"linear_input": linear_input,
"stop_targets": stop_targets,
"attn_mask": attn_mask,
"durations": durations,
"speaker_ids": speaker_ids,
"d_vectors": d_vectors,
"max_text_length": max_text_length,
"max_spec_length": max_spec_length,
"item_idx": item_idx,
}
def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
if hasattr(self.model, "module"):
return self.model.module.train_step(batch, criterion)
return self.model.train_step(batch, criterion)
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
self.on_train_step_start()
step_start_time = time.time()
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
# zero-out optimizer
self.optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self._train_step(batch, self.criterion)
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
# optimizer step
if self.config.mixed_precision:
# model optimizer step in mixed precision mode
self.scaler.scale(loss_dict["loss"]).backward()
self.scaler.unscale_(self.optimizer)
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# main model optimizer step
loss_dict["loss"].backward()
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
self.optimizer.step()
step_time = time.time() - step_start_time
# setup lr
if self.config.lr_scheduler:
self.scheduler.step()
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
self.keep_avg_train.update_values(update_train_values)
# print training progress
current_lr = self.optimizer.param_groups[0]["lr"]
if self.total_steps_done % self.config.print_step == 0:
log_dict = {
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
"max_text_length": [batch["max_text_length"], 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
)
if self.args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if self.total_steps_done % self.config.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
}
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
if self.total_steps_done % self.config.save_step == 0:
if self.config.checkpoint:
# save model
save_checkpoint(
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
model_loss=loss_dict["loss"],
characters=self.model_characters,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
# training visualizations
if hasattr(self.model, "module"):
figures, audios = self.model.module.train_log(self.ap, batch, outputs)
else:
figures, audios = self.model.train_log(self.ap, batch, outputs)
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
self.total_steps_done += 1
self.on_train_step_end()
return outputs, loss_dict
def train_epoch(self) -> None:
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
if hasattr(self.model, "module"):
return self.model.module.eval_step(batch, self.criterion)
return self.model.eval_step(batch, self.criterion)
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
with torch.no_grad():
step_start_time = time.time()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self._eval_step(batch)
step_time = time.time() - step_start_time
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
update_eval_values["avg_step_time"] = step_time
self.keep_avg_eval.update_values(update_eval_values)
if self.config.print_eval:
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
return outputs, loss_dict
def eval_epoch(self) -> None:
self.model.eval()
self.c_logger.print_eval_start()
loader_start_time = time.time()
batch = None
for cur_step, batch in enumerate(self.eval_loader):
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step)
# Plot epoch stats and samples from the last batch.
if self.args.rank == 0:
if hasattr(self.model, "module"):
figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
else:
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
def test_run(
self,
) -> None:
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_aux_inputs()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
self.model,
sen,
self.config,
self.use_cuda,
self.ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
def _get_aux_inputs(self) -> Dict:
# setup speaker_id
speaker_id = 0 if self.config.use_speaker_embedding else None
# setup d_vector
d_vector = (
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
else None
)
# setup style_mel
if self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input
else:
style_wav = None
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# initialize GST with a zero dict.
style_wav = {}
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
for i in range(self.config.gst["gst_num_style_tokens"]):
style_wav[str(i)] = 0
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
return aux_inputs
def fit(self) -> None:
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
# define data loaders
self.train_loader = self.get_train_dataloader(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_ids=self.speaker_manager.speaker_ids,
d_vectors=self.speaker_manager.d_vectors,
)
self.eval_loader = (
self.get_eval_dataloder(
self.config.r,
self.ap,
self.data_eval,
verbose=True,
speaker_ids=self.speaker_manager.speaker_ids,
d_vectors=self.speaker_manager.d_vectors,
)
if self.config.run_eval
else None
)
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
self.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
)
self.save_best_model()
self.on_epoch_end()
def save_best_model(self) -> None:
self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
self.best_loss,
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
self.model_characters,
keep_all_best=self.config.keep_all_best,
keep_after=self.config.keep_after,
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
)
@staticmethod
def _setup_logger_config(log_file: str) -> None:
logging.basicConfig(
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
def on_epoch_start(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_epoch_start"):
self.model.on_epoch_start(self)
if hasattr(self.criterion, "on_epoch_start"):
self.criterion.on_epoch_start(self)
if hasattr(self.optimizer, "on_epoch_start"):
self.optimizer.on_epoch_start(self)
def on_epoch_end(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_epoch_end"):
self.model.on_epoch_end(self)
if hasattr(self.criterion, "on_epoch_end"):
self.criterion.on_epoch_end(self)
if hasattr(self.optimizer, "on_epoch_end"):
self.optimizer.on_epoch_end(self)
def on_train_step_start(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_train_step_start"):
self.model.on_train_step_start(self)
if hasattr(self.criterion, "on_train_step_start"):
self.criterion.on_train_step_start(self)
if hasattr(self.optimizer, "on_train_step_start"):
self.optimizer.on_train_step_start(self)
def on_train_step_end(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_train_step_end"):
self.model.on_train_step_end(self)
if hasattr(self.criterion, "on_train_step_end"):
self.criterion.on_train_step_end(self)
if hasattr(self.optimizer, "on_train_step_end"):
self.optimizer.on_train_step_end(self)
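Because TrainerAbstract leaves nearly the whole loop abstract, adding another trainer (say, for vocoders) means subclassing it and supplying the same set of methods TrainerTTS implements above. A bare-bones sketch of what such a subclass has to provide; names and bodies below are placeholders, not part of this commit:

from typing import Dict, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.trainer import TrainerAbstract

class MyVocoderTrainer(TrainerAbstract):
    """Placeholder subclass sketching the contract only.

    Every @abstractmethod on TrainerAbstract must be overridden before the
    class can be instantiated: get_model, get_optimizer, get_scheduler,
    get_criterion, restore_model, get_train_dataloader, get_eval_dataloder,
    format_batch, _train_step, train_step, train_epoch, _eval_step,
    eval_step, eval_epoch, test_run, fit, save_best_model and the four
    on_* callbacks. Only a few are sketched here.
    """

    @staticmethod
    def get_model(config: Coqpit) -> nn.Module:
        raise NotImplementedError  # build and return the model for `config`

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        # simplest possible choice; TrainerTTS resolves the optimizer name from config
        return torch.optim.Adam(model.parameters(), lr=config.lr)

    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        # delegate to the model, mirroring TrainerTTS._train_step above
        return self.model.train_step(batch, criterion)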