mirror of https://github.com/coqui-ai/TTS.git

commit 64f0f57757 (parent f077a356e0)

    `TrainerAbstract` and related updates for `TrainerTTS`
@@ -2,7 +2,7 @@ import os
 import sys
 import traceback
 
-from TTS.trainer import TrainerTTS
+from TTS.tts.trainer_tts import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
 
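The hunk above is the only change to this training script: `TrainerTTS` now lives in `TTS.tts.trainer_tts`. For orientation, here is a minimal sketch of how such an entry point is typically wired around the moved class, assuming `init_training` returns the parsed arguments, config, output paths, and loggers; the exact return tuple is not visible in this diff, so the unpacking below is illustrative only.

# Sketch (illustrative, not from the commit): typical wiring of the entry
# point around the moved TrainerTTS. The init_training unpacking is assumed.
import sys
import traceback

from TTS.tts.trainer_tts import TrainerTTS
from TTS.utils.arguments import init_training
from TTS.utils.generic_utils import remove_experiment_folder


def main():
    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path)
    try:
        trainer.fit()
    except KeyboardInterrupt:
        remove_experiment_folder(output_path)
        sys.exit(0)
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(output_path)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()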

TTS/trainer.py (714 changed lines)

@@ -1,39 +1,23 @@
 # -*- coding: utf-8 -*-
 
 import importlib
-import logging
-import os
-import time
-from argparse import Namespace
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, TypeVar
 
 import torch
 from coqpit import Coqpit
 
 # DISTRIBUTED
 from torch import nn
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 
-from TTS.tts.datasets import TTSDataset, load_meta_data
-from TTS.tts.layers import setup_loss
-from TTS.tts.models import setup_model
-from TTS.tts.utils.io import save_best_model, save_checkpoint
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
-from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.symbols import make_symbols
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import init_distributed
-from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
-from TTS.utils.training import check_update, setup_torch_training_env
+_DataLoader = TypeVar("_DataLoader")
 
 
 @dataclass
 class TrainingArgs(Coqpit):
+    """Trainer arguments that are parsed externally (e.g. CLI)"""
 
     continue_path: str = field(
         default="",
         metadata={
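The `TypeVar` replacing the torch imports is what lets the abstract base annotate its dataloader factories without depending on `torch.utils.data`. A standalone sketch of the pattern (not from the commit; strictly, a `TypeVar` used only in a return position acts as a lightweight placeholder, which matches how the commit uses it):

# Sketch: TypeVar as a lightweight return-annotation placeholder; the module
# carrying the abstract class needs no torch import, while concrete
# subclasses return a real torch DataLoader.
from typing import TypeVar

_DataLoader = TypeVar("_DataLoader")


class LoaderFactoryBase:
    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
        raise NotImplementedError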
@@ -58,676 +42,100 @@ class TrainingArgs(Coqpit):
 
 
 # pylint: disable=import-outside-toplevel, too-many-public-methods
-class TrainerTTS:
-    use_cuda, num_gpus = setup_torch_training_env(True, False)

[… `__init__` and the methods get_model, get_optimizer, get_character_processor, get_speaker_manager, get_scheduler, get_criterion, restore_model, _get_loader, get_train_dataloader, get_eval_dataloder, format_batch, _train_step, train_step, train_epoch, _eval_step, eval_step, eval_epoch, test_run and _get_aux_inputs are removed here; they move verbatim into the new TTS/tts/trainer_tts.py, reproduced in full below …]

-    def fit(self) -> None:
-        if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
-            print(f" > Starting with loaded last best loss {self.best_loss}.")
-
-        # define data loaders
-        self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_ids=self.speaker_manager.speaker_ids,
-            d_vectors=self.speaker_manager.d_vectors,
-        )
-        self.eval_loader = (
-            self.get_eval_dataloder(
-                self.config.r,
-                self.ap,
-                self.data_train,
-                verbose=True,
-                speaker_ids=self.speaker_manager.speaker_ids,
-                d_vectors=self.speaker_manager.d_vectors,
-            )
-            if self.config.run_eval
-            else None
-        )
-
-        self.total_steps_done = self.restore_step
-
-        for epoch in range(0, self.config.epochs):
-            self.on_epoch_start()
-            self.keep_avg_train = KeepAverage()
-            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
-            self.epochs_done = epoch
-            self.c_logger.print_epoch_start(epoch, self.config.epochs)
-            self.train_epoch()
-            if self.config.run_eval:
-                self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
-                self.test_run()
-            self.c_logger.print_epoch_end(
-                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
-            )
-            self.save_best_model()
-            self.on_epoch_end()
-
-    def save_best_model(self) -> None:
-        self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
-            self.best_loss,
-            self.model,
-            self.optimizer,
-            self.total_steps_done,
-            self.epochs_done,
-            self.config.r,
-            self.output_path,
-            self.model_characters,
-            keep_all_best=self.config.keep_all_best,
-            keep_after=self.config.keep_after,
-            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
-        )
-
-    @staticmethod
-    def _setup_logger_config(log_file: str) -> None:
-        logging.basicConfig(
-            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
-        )
-
-    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_epoch_start"):
-            self.model.on_epoch_start(self)
-
-        if hasattr(self.criterion, "on_epoch_start"):
-            self.criterion.on_epoch_start(self)
-
-        if hasattr(self.optimizer, "on_epoch_start"):
-            self.optimizer.on_epoch_start(self)
-
-    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_epoch_end"):
-            self.model.on_epoch_end(self)
-
-        if hasattr(self.criterion, "on_epoch_end"):
-            self.criterion.on_epoch_end(self)
-
-        if hasattr(self.optimizer, "on_epoch_end"):
-            self.optimizer.on_epoch_end(self)
-
-    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_train_step_start"):
-            self.model.on_train_step_start(self)
-
-        if hasattr(self.criterion, "on_train_step_start"):
-            self.criterion.on_train_step_start(self)
-
-        if hasattr(self.optimizer, "on_train_step_start"):
-            self.optimizer.on_train_step_start(self)
-
-    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_train_step_end"):
-            self.model.on_train_step_end(self)
-
-        if hasattr(self.criterion, "on_train_step_end"):
-            self.criterion.on_train_step_end(self)
-
-        if hasattr(self.optimizer, "on_train_step_end"):
-            self.optimizer.on_train_step_end(self)
+class TrainerAbstract(ABC):
+    @staticmethod
+    def _is_apex_available():
+        return importlib.util.find_spec("apex") is not None
+
+    @staticmethod
+    @abstractmethod
+    def get_model(*args, **kwargs) -> nn.Module:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_scheduler(
+        config: Coqpit, optimizer: torch.optim.Optimizer
+    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_criterion(config: Coqpit) -> nn.Module:
+        pass
+
+    @abstractmethod
+    def restore_model(self, *args, **kwargs) -> Tuple:
+        pass
+
+    @abstractmethod
+    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
+        pass
+
+    @abstractmethod
+    def get_eval_dataloder(self, *args, **kwargs) -> _DataLoader:
+        pass
+
+    @abstractmethod
+    def format_batch(self, batch: List) -> Dict:
+        pass
+
+    @abstractmethod
+    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        pass
+
+    @abstractmethod
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
+        pass
+
+    @abstractmethod
+    def train_epoch(self) -> None:
+        pass
+
+    @abstractmethod
+    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
+        pass
+
+    @abstractmethod
+    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
+        pass
+
+    @abstractmethod
+    def eval_epoch(self) -> None:
+        pass
+
+    @abstractmethod
+    def test_run(self) -> None:
+        pass
+
+    @abstractmethod
+    def fit(self) -> None:
+        pass
+
+    @abstractmethod
+    def save_best_model(self) -> None:
+        pass
+
+    @abstractmethod
+    def on_epoch_start(self) -> None:
+        pass
+
+    @abstractmethod
+    def on_epoch_end(self) -> None:
+        pass
+
+    @abstractmethod
+    def on_train_step_start(self) -> None:
+        pass
+
+    @abstractmethod
+    def on_train_step_end(self) -> None:
+        pass
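The new `TTS/tts/trainer_tts.py` below keeps the full training logic and plugs into this base via `class TrainerTTS(TrainerAbstract)`. As a self-contained sketch of what the abstract contract buys (hypothetical classes, not part of the commit), `ABCMeta` refuses to instantiate a trainer until every `@abstractmethod` is overridden:

# Sketch (hypothetical classes): ABCMeta enforces a TrainerAbstract-style
# contract at instantiation time rather than at call time.
from abc import ABC, abstractmethod


class TrainerBase(ABC):
    @abstractmethod
    def fit(self) -> None:
        pass


class BrokenTrainer(TrainerBase):  # forgets to implement fit()
    pass


class GoodTrainer(TrainerBase):
    def fit(self) -> None:
        print("training...")


GoodTrainer().fit()  # prints "training..."
try:
    BrokenTrainer()
except TypeError as err:  # can't instantiate abstract class BrokenTrainer
    print(err)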
|
@ -0,0 +1,709 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from argparse import Namespace
|
||||||
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
|
|
||||||
|
# DISTRIBUTED
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from TTS.trainer import TrainerAbstract
|
||||||
|
from TTS.tts.datasets import TTSDataset, load_meta_data
|
||||||
|
from TTS.tts.layers import setup_loss
|
||||||
|
from TTS.tts.models import setup_model
|
||||||
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||||
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
|
from TTS.tts.utils.text.symbols import make_symbols
|
||||||
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.distribute import init_distributed
|
||||||
|
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
|
||||||
|
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
||||||
|
from TTS.utils.training import check_update, setup_torch_training_env
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=import-outside-toplevel, too-many-public-methods
|
||||||
|
|
||||||
|
class TrainerTTS(TrainerAbstract):
|
||||||
|
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: Union[Coqpit, Namespace],
|
||||||
|
config: Coqpit,
|
||||||
|
c_logger: ConsoleLogger = None,
|
||||||
|
tb_logger: TensorboardLogger = None,
|
||||||
|
model: nn.Module = None,
|
||||||
|
output_path: str = None,
|
||||||
|
) -> None:
|
||||||
|
self.args = args
|
||||||
|
self.config = config
|
||||||
|
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
|
||||||
|
if tb_logger is None:
|
||||||
|
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
|
||||||
|
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
|
else:
|
||||||
|
self.tb_logger = tb_logger
|
||||||
|
self.output_path = output_path
|
||||||
|
|
||||||
|
self.total_steps_done = 0
|
||||||
|
self.epochs_done = 0
|
||||||
|
self.restore_step = 0
|
||||||
|
self.best_loss = float("inf")
|
||||||
|
self.train_loader = None
|
||||||
|
self.eval_loader = None
|
||||||
|
self.output_audio_path = os.path.join(output_path, "test_audios")
|
||||||
|
|
||||||
|
self.keep_avg_train = None
|
||||||
|
self.keep_avg_eval = None
|
||||||
|
|
||||||
|
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
|
||||||
|
self._setup_logger_config(log_file)
|
||||||
|
|
||||||
|
# model, audio processor, datasets, loss
|
||||||
|
# init audio processor
|
||||||
|
self.ap = AudioProcessor(**self.config.audio.to_dict())
|
||||||
|
|
||||||
|
# init character processor
|
||||||
|
self.model_characters = self.get_character_processor(self.config)
|
||||||
|
|
||||||
|
# load dataset samples
|
||||||
|
self.data_train, self.data_eval = load_meta_data(self.config.datasets)
|
||||||
|
|
||||||
|
# default speaker manager
|
||||||
|
self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train)
|
||||||
|
|
||||||
|
# init TTS model
|
||||||
|
if model is not None:
|
||||||
|
self.model = model
|
||||||
|
else:
|
||||||
|
self.model = self.get_model(
|
||||||
|
len(self.model_characters),
|
||||||
|
self.speaker_manager.num_speakers,
|
||||||
|
self.config,
|
||||||
|
self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup criterion
|
||||||
|
self.criterion = self.get_criterion(self.config)
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
if self.num_gpus > 1:
|
||||||
|
init_distributed(
|
||||||
|
args.rank,
|
||||||
|
self.num_gpus,
|
||||||
|
args.group_id,
|
||||||
|
self.config.distributed_backend,
|
||||||
|
self.config.distributed_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_cuda:
|
||||||
|
self.model.cuda()
|
||||||
|
self.criterion.cuda()
|
||||||
|
|
||||||
|
# scalers for mixed precision training
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
|
||||||
|
|
||||||
|
# setup optimizer
|
||||||
|
self.optimizer = self.get_optimizer(self.model, self.config)
|
||||||
|
|
||||||
|
if self.args.restore_path:
|
||||||
|
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
|
||||||
|
self.config, args.restore_path, self.model, self.optimizer, self.scaler
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup scheduler
|
||||||
|
self.scheduler = self.get_scheduler(self.config, self.optimizer)
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
if self.num_gpus > 1:
|
||||||
|
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
|
||||||
|
|
||||||
|
# count model size
|
||||||
|
num_params = count_parameters(self.model)
|
||||||
|
print("\n > Model has {} parameters".format(num_params))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module:
|
||||||
|
model = setup_model(num_chars, num_speakers, config, d_vector_dim)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
|
||||||
|
optimizer_name = config.optimizer
|
||||||
|
optimizer_params = config.optimizer_params
|
||||||
|
if optimizer_name.lower() == "radam":
|
||||||
|
module = importlib.import_module("TTS.utils.radam")
|
||||||
|
optimizer = getattr(module, "RAdam")
|
||||||
|
else:
|
||||||
|
optimizer = getattr(torch.optim, optimizer_name)
|
||||||
|
return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_character_processor(config: Coqpit) -> str:
|
||||||
|
# setup custom characters if set in config file.
|
||||||
|
# TODO: implement CharacterProcessor
|
||||||
|
if config.characters is not None:
|
||||||
|
symbols, phonemes = make_symbols(**config.characters.to_dict())
|
||||||
|
else:
|
||||||
|
from TTS.tts.utils.text.symbols import phonemes, symbols
|
||||||
|
model_characters = phonemes if config.use_phonemes else symbols
|
||||||
|
return model_characters
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_speaker_manager(
|
||||||
|
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
|
||||||
|
) -> SpeakerManager:
|
||||||
|
speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
|
||||||
|
return speaker_manager
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_scheduler(
|
||||||
|
config: Coqpit, optimizer: torch.optim.Optimizer
|
||||||
|
) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access
|
||||||
|
lr_scheduler = config.lr_scheduler
|
||||||
|
lr_scheduler_params = config.lr_scheduler_params
|
||||||
|
if lr_scheduler is None:
|
||||||
|
return None
|
||||||
|
if lr_scheduler.lower() == "noamlr":
|
||||||
|
from TTS.utils.training import NoamLR
|
||||||
|
|
||||||
|
scheduler = NoamLR
|
||||||
|
else:
|
||||||
|
scheduler = getattr(torch.optim, lr_scheduler)
|
||||||
|
return scheduler(optimizer, **lr_scheduler_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_criterion(config: Coqpit) -> nn.Module:
|
||||||
|
return setup_loss(config)
|
||||||
|
|
||||||
|
def restore_model(
|
||||||
|
self,
|
||||||
|
config: Coqpit,
|
||||||
|
restore_path: str,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
scaler: torch.cuda.amp.GradScaler = None,
|
||||||
|
) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
|
||||||
|
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||||
|
checkpoint = torch.load(restore_path)
|
||||||
|
try:
|
||||||
|
print(" > Restoring Model...")
|
||||||
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
print(" > Restoring Optimizer...")
|
||||||
|
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||||
|
if "scaler" in checkpoint and config.mixed_precision:
|
||||||
|
print(" > Restoring AMP Scaler...")
|
||||||
|
scaler.load_state_dict(checkpoint["scaler"])
|
||||||
|
except (KeyError, RuntimeError):
|
||||||
|
print(" > Partial model initialization...")
|
||||||
|
model_dict = model.state_dict()
|
||||||
|
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
|
||||||
|
model.load_state_dict(model_dict)
|
||||||
|
del model_dict
|
||||||
|
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
group["lr"] = self.config.lr
|
||||||
|
print(
|
||||||
|
" > Model restored from step %d" % checkpoint["step"],
|
||||||
|
)
|
||||||
|
restore_step = checkpoint["step"]
|
||||||
|
return model, optimizer, scaler, restore_step
|
||||||
|
|
||||||
|
def _get_loader(
|
||||||
|
self,
|
||||||
|
r: int,
|
||||||
|
ap: AudioProcessor,
|
||||||
|
is_eval: bool,
|
||||||
|
data_items: List,
|
||||||
|
verbose: bool,
|
||||||
|
speaker_ids: Union[Dict, List],
|
||||||
|
d_vectors: Union[Dict, List],
|
||||||
|
) -> DataLoader:
|
||||||
|
if is_eval and not self.config.run_eval:
|
||||||
|
loader = None
|
||||||
|
else:
|
||||||
|
dataset = TTSDataset(
|
||||||
|
outputs_per_step=r,
|
||||||
|
text_cleaner=self.config.text_cleaner,
|
||||||
|
compute_linear_spec=self.config.model.lower() == "tacotron",
|
||||||
|
meta_data=data_items,
|
||||||
|
ap=ap,
|
||||||
|
tp=self.config.characters,
|
||||||
|
add_blank=self.config["add_blank"],
|
||||||
|
batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
|
||||||
|
min_seq_len=self.config.min_seq_len,
|
||||||
|
max_seq_len=self.config.max_seq_len,
|
||||||
|
phoneme_cache_path=self.config.phoneme_cache_path,
|
||||||
|
use_phonemes=self.config.use_phonemes,
|
||||||
|
phoneme_language=self.config.phoneme_language,
|
||||||
|
enable_eos_bos=self.config.enable_eos_bos_chars,
|
||||||
|
use_noise_augment=not is_eval,
|
||||||
|
verbose=verbose,
|
||||||
|
speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
|
||||||
|
d_vector_mapping=d_vectors
|
||||||
|
if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.use_phonemes and self.config.compute_input_seq_cache:
|
||||||
|
# precompute phonemes to have a better estimate of sequence lengths.
|
||||||
|
dataset.compute_input_seq(self.config.num_loader_workers)
|
||||||
|
dataset.sort_items()
|
||||||
|
|
||||||
|
sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
|
||||||
|
loader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=False,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
|
||||||
|
pin_memory=False,
|
||||||
|
)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
def get_train_dataloader(
|
||||||
|
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
|
||||||
|
) -> DataLoader:
|
||||||
|
return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
|
||||||
|
|
||||||
|
def get_eval_dataloder(
|
||||||
|
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
|
||||||
|
) -> DataLoader:
|
||||||
|
return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
|
||||||
|
|
||||||
|
def format_batch(self, batch: List) -> Dict:
|
||||||
|
# setup input batch
|
||||||
|
text_input = batch[0]
|
||||||
|
text_lengths = batch[1]
|
||||||
|
speaker_names = batch[2]
|
||||||
|
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||||
|
mel_input = batch[4]
|
||||||
|
mel_lengths = batch[5]
|
||||||
|
stop_targets = batch[6]
|
||||||
|
item_idx = batch[7]
|
||||||
|
d_vectors = batch[8]
|
||||||
|
speaker_ids = batch[9]
|
||||||
|
attn_mask = batch[10]
|
||||||
|
max_text_length = torch.max(text_lengths.float())
|
||||||
|
max_spec_length = torch.max(mel_lengths.float())
|
||||||
|
|
||||||
|
# compute durations from attention masks
|
||||||
|
durations = None
|
||||||
|
if attn_mask is not None:
|
||||||
|
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
||||||
|
for idx, am in enumerate(attn_mask):
|
||||||
|
# compute raw durations
|
||||||
|
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
|
||||||
|
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
||||||
|
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
||||||
|
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
||||||
|
dur[c_idxs] = counts
|
||||||
|
# smooth the durations and set any 0 duration to 1
|
||||||
|
# by cutting off from the largest duration indeces.
|
||||||
|
extra_frames = dur.sum() - mel_lengths[idx]
|
||||||
|
largest_idxs = torch.argsort(-dur)[:extra_frames]
|
||||||
|
dur[largest_idxs] -= 1
|
||||||
|
assert (
|
||||||
|
dur.sum() == mel_lengths[idx]
|
||||||
|
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||||
|
durations[idx, : text_lengths[idx]] = dur
|
||||||
|
|
||||||
|
# set stop targets view, we predict a single stop token per iteration.
|
||||||
|
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||||
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||||
|
|
||||||
|
# dispatch batch to GPU
|
||||||
|
if self.use_cuda:
|
||||||
|
text_input = to_cuda(text_input)
|
||||||
|
text_lengths = to_cuda(text_lengths)
|
||||||
|
mel_input = to_cuda(mel_input)
|
||||||
|
mel_lengths = to_cuda(mel_lengths)
|
||||||
|
linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
|
||||||
|
stop_targets = to_cuda(stop_targets)
|
||||||
|
attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
|
||||||
|
durations = to_cuda(durations) if attn_mask is not None else None
|
||||||
|
if speaker_ids is not None:
|
||||||
|
speaker_ids = to_cuda(speaker_ids)
|
||||||
|
if d_vectors is not None:
|
||||||
|
d_vectors = to_cuda(d_vectors)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"text_input": text_input,
|
||||||
|
"text_lengths": text_lengths,
|
||||||
|
"speaker_names": speaker_names,
|
||||||
|
"mel_input": mel_input,
|
||||||
|
"mel_lengths": mel_lengths,
|
||||||
|
"linear_input": linear_input,
|
||||||
|
"stop_targets": stop_targets,
|
||||||
|
"attn_mask": attn_mask,
|
||||||
|
"durations": durations,
|
||||||
|
"speaker_ids": speaker_ids,
|
||||||
|
"d_vectors": d_vectors,
|
||||||
|
"max_text_length": max_text_length,
|
||||||
|
"max_spec_length": max_spec_length,
|
||||||
|
"item_idx": item_idx,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||||
|
if hasattr(self.model, "module"):
|
||||||
|
return self.model.module.train_step(batch, criterion)
|
||||||
|
return self.model.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
|
||||||
|
self.on_train_step_start()
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
# format data
|
||||||
|
batch = self.format_batch(batch)
|
||||||
|
loader_time = time.time() - loader_start_time
|
||||||
|
|
||||||
|
# zero-out optimizer
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||||
|
outputs, loss_dict = self._train_step(batch, self.criterion)
|
||||||
|
|
||||||
|
# check nan loss
|
||||||
|
if torch.isnan(loss_dict["loss"]).any():
|
||||||
|
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
|
||||||
|
|
||||||
|
# optimizer step
|
||||||
|
if self.config.mixed_precision:
|
||||||
|
# model optimizer step in mixed precision mode
|
||||||
|
self.scaler.scale(loss_dict["loss"]).backward()
|
||||||
|
self.scaler.unscale_(self.optimizer)
|
||||||
|
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
else:
|
||||||
|
# main model optimizer step
|
||||||
|
loss_dict["loss"].backward()
|
||||||
|
grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
step_time = time.time() - step_start_time
|
||||||
|
|
||||||
|
# setup lr
|
||||||
|
if self.config.lr_scheduler:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
# detach loss values
|
||||||
|
loss_dict_new = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict_new[key] = value
|
||||||
|
else:
|
||||||
|
loss_dict_new[key] = value.item()
|
||||||
|
loss_dict = loss_dict_new
|
||||||
|
|
||||||
|
# update avg stats
|
||||||
|
update_train_values = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
update_train_values["avg_" + key] = value
|
||||||
|
update_train_values["avg_loader_time"] = loader_time
|
||||||
|
update_train_values["avg_step_time"] = step_time
|
||||||
|
self.keep_avg_train.update_values(update_train_values)
|
||||||
|
|
||||||
|
# print training progress
|
||||||
|
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||||
|
if self.total_steps_done % self.config.print_step == 0:
|
||||||
|
log_dict = {
|
||||||
|
"max_spec_length": [batch["max_spec_length"], 1], # value, precision
|
||||||
|
"max_text_length": [batch["max_text_length"], 1],
|
||||||
|
"step_time": [step_time, 4],
|
||||||
|
"loader_time": [loader_time, 2],
|
||||||
|
"current_lr": current_lr,
|
||||||
|
}
|
||||||
|
self.c_logger.print_train_step(
|
||||||
|
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.rank == 0:
|
||||||
|
# Plot Training Iter Stats
|
||||||
|
# reduce TB load
|
||||||
|
if self.total_steps_done % self.config.tb_plot_step == 0:
|
||||||
|
iter_stats = {
|
||||||
|
"lr": current_lr,
|
||||||
|
"grad_norm": grad_norm,
|
||||||
|
"step_time": step_time,
|
||||||
|
}
|
||||||
|
iter_stats.update(loss_dict)
|
||||||
|
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
|
||||||
|
|
||||||
|
if self.total_steps_done % self.config.save_step == 0:
|
||||||
|
if self.config.checkpoint:
|
||||||
|
# save model
|
||||||
|
save_checkpoint(
|
||||||
|
self.model,
|
||||||
|
self.optimizer,
|
||||||
|
self.total_steps_done,
|
||||||
|
self.epochs_done,
|
||||||
|
self.config.r,
|
||||||
|
self.output_path,
|
||||||
|
model_loss=loss_dict["loss"],
|
||||||
|
characters=self.model_characters,
|
||||||
|
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||||
|
)
|
||||||
|
# training visualizations
|
||||||
|
if hasattr(self.model, "module"):
|
||||||
|
figures, audios = self.model.module.train_log(self.ap, batch, outputs)
|
||||||
|
else:
|
||||||
|
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
||||||
|
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
||||||
|
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
|
||||||
|
self.total_steps_done += 1
|
||||||
|
self.on_train_step_end()
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_epoch(self) -> None:
|
||||||
|
self.model.train()
|
||||||
|
epoch_start_time = time.time()
|
||||||
|
if self.use_cuda:
|
||||||
|
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
|
||||||
|
else:
|
||||||
|
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
|
||||||
|
self.c_logger.print_train_start()
|
||||||
|
loader_start_time = time.time()
|
||||||
|
for cur_step, batch in enumerate(self.train_loader):
|
||||||
|
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
|
||||||
|
epoch_time = time.time() - epoch_start_time
|
||||||
|
# Plot self.epochs_done Stats
|
||||||
|
if self.args.rank == 0:
|
||||||
|
epoch_stats = {"epoch_time": epoch_time}
|
||||||
|
epoch_stats.update(self.keep_avg_train.avg_values)
|
||||||
|
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
|
||||||
|
if self.config.tb_model_param_stats:
|
||||||
|
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
||||||
|
|
||||||
|
def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
|
||||||
|
if hasattr(self.model, "module"):
|
||||||
|
return self.model.module.eval_step(batch, self.criterion)
|
||||||
|
return self.model.eval_step(batch, self.criterion)
|
||||||
|
|
||||||
|
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||||
|
with torch.no_grad():
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||||
|
outputs, loss_dict = self._eval_step(batch)
|
||||||
|
|
||||||
|
step_time = time.time() - step_start_time
|
||||||
|
|
||||||
|
# detach loss values
|
||||||
|
loss_dict_new = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict_new[key] = value
|
||||||
|
else:
|
||||||
|
loss_dict_new[key] = value.item()
|
||||||
|
loss_dict = loss_dict_new
|
||||||
|
|
||||||
|
# update avg stats
|
||||||
|
update_eval_values = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
update_eval_values["avg_" + key] = value
|
||||||
|
update_eval_values["avg_step_time"] = step_time
|
||||||
|
self.keep_avg_eval.update_values(update_eval_values)
|
||||||
|
|
||||||
|
if self.config.print_eval:
|
||||||
|
self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
    def eval_epoch(self) -> None:
        self.model.eval()
        self.c_logger.print_eval_start()
        loader_start_time = time.time()
        batch = None
        for cur_step, batch in enumerate(self.eval_loader):
            # format data
            batch = self.format_batch(batch)
            loader_time = time.time() - loader_start_time
            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
            outputs, _ = self.eval_step(batch, cur_step)
        # plot epoch stats and samples from the last batch
        if self.args.rank == 0:
            if hasattr(self.model, "module"):
                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
            else:
                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
            self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
            self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

    def test_run(self) -> None:
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_aux_inputs()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, model_outputs, _ = synthesis(
                self.model,
                sen,
                self.config,
                self.use_cuda,
                self.ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()

            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
            os.makedirs(file_path, exist_ok=True)
            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
            self.ap.save_wav(wav, file_path)
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)

        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
        self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
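
    # Unpacking `synthesis(...).values()` relies on dict insertion order being
    # stable (guaranteed since Python 3.7) and on the return dict holding
    # exactly four entries in the order wav, alignment, model outputs, and one
    # unused field. Keyed access would be more robust; a sketch (the key names
    # here are illustrative, not the actual return keys):
    #
    #     out = synthesis(self.model, sen, self.config, self.use_cuda, self.ap)
    #     wav, alignment = out["wav"], out["alignments"]
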
    def _get_aux_inputs(self) -> Dict:
        # setup speaker_id
        speaker_id = 0 if self.config.use_speaker_embedding else None
        # setup d_vector
        d_vector = (
            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
            else None
        )
        # setup style_mel
        if self.config.has("gst_style_input"):
            style_wav = self.config.gst_style_input
        else:
            style_wav = None
        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
            # initialize GST with a zero dict
            style_wav = {}
            print("WARNING: You did not provide a GST style wav, so a zero tensor is used instead!")
            for i in range(self.config.gst["gst_num_style_tokens"]):
                style_wav[str(i)] = 0
        aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
        return aux_inputs
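
    # Example of the zero style dict built above, assuming
    # gst_num_style_tokens == 10:
    #
    #     style_wav = {str(i): 0 for i in range(10)}
    #     # -> {"0": 0, "1": 0, ..., "9": 0}
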
    def fit(self) -> None:
        if self.restore_step != 0 or self.args.best_path:
            print(f" > Restoring best loss from {os.path.basename(self.args.best_path)} ...")
            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

        # define data loaders
        self.train_loader = self.get_train_dataloader(
            self.config.r,
            self.ap,
            self.data_train,
            verbose=True,
            speaker_ids=self.speaker_manager.speaker_ids,
            d_vectors=self.speaker_manager.d_vectors,
        )
        self.eval_loader = (
            self.get_eval_dataloader(
                self.config.r,
                self.ap,
                self.data_eval,
                verbose=True,
                speaker_ids=self.speaker_manager.speaker_ids,
                d_vectors=self.speaker_manager.d_vectors,
            )
            if self.config.run_eval
            else None
        )

        self.total_steps_done = self.restore_step

        for epoch in range(0, self.config.epochs):
            self.on_epoch_start()
            self.keep_avg_train = KeepAverage()
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
            self.c_logger.print_epoch_start(epoch, self.config.epochs)
            self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                self.test_run()
            self.c_logger.print_epoch_end(
                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
            )
            self.save_best_model()
            self.on_epoch_end()
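
    # Typical wiring for this trainer from a training script; a sketch,
    # assuming `init_training` returns the parsed args, config, output path
    # and loggers in this order (check TTS.utils.arguments for the exact
    # signature):
    #
    #     args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
    #     trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path)
    #     trainer.fit()
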
    def save_best_model(self) -> None:
        self.best_loss = save_best_model(
            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
            self.best_loss,
            self.model,
            self.optimizer,
            self.total_steps_done,
            self.epochs_done,
            self.config.r,
            self.output_path,
            self.model_characters,
            keep_all_best=self.config.keep_all_best,
            keep_after=self.config.keep_after,
            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
        )
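
    # `save_best_model` is expected to compare the current loss against the
    # running best and return the (possibly updated) best loss. A minimal
    # sketch of that contract (not the actual TTS.tts.utils.io implementation):
    #
    #     def save_best_model(current_loss, best_loss, *args, **kwargs):
    #         if current_loss < best_loss:
    #             ...  # write the checkpoint to disk
    #             best_loss = current_loss
    #         return best_loss
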
    @staticmethod
    def _setup_logger_config(log_file: str) -> None:
        logging.basicConfig(
            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
        )

    def on_epoch_start(self) -> None:
        if hasattr(self.model, "on_epoch_start"):
            self.model.on_epoch_start(self)

        if hasattr(self.criterion, "on_epoch_start"):
            self.criterion.on_epoch_start(self)

        if hasattr(self.optimizer, "on_epoch_start"):
            self.optimizer.on_epoch_start(self)

    def on_epoch_end(self) -> None:
        if hasattr(self.model, "on_epoch_end"):
            self.model.on_epoch_end(self)

        if hasattr(self.criterion, "on_epoch_end"):
            self.criterion.on_epoch_end(self)

        if hasattr(self.optimizer, "on_epoch_end"):
            self.optimizer.on_epoch_end(self)

    def on_train_step_start(self) -> None:
        if hasattr(self.model, "on_train_step_start"):
            self.model.on_train_step_start(self)

        if hasattr(self.criterion, "on_train_step_start"):
            self.criterion.on_train_step_start(self)

        if hasattr(self.optimizer, "on_train_step_start"):
            self.optimizer.on_train_step_start(self)

    def on_train_step_end(self) -> None:
        if hasattr(self.model, "on_train_step_end"):
            self.model.on_train_step_end(self)

        if hasattr(self.criterion, "on_train_step_end"):
            self.criterion.on_train_step_end(self)

        if hasattr(self.optimizer, "on_train_step_end"):
            self.optimizer.on_train_step_end(self)
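
    # Any of the model, criterion, or optimizer objects can opt into these
    # callbacks just by defining the matching method; no registration is
    # needed. A hypothetical model-side example:
    #
    #     class MyTTSModel(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.kl_weight = 0.0
    #
    #         def on_epoch_start(self, trainer):
    #             # anneal a loss weight over the first 100 epochs (illustrative)
    #             self.kl_weight = min(1.0, trainer.epochs_done / 100)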