Typing annotations for the trainer

This commit is contained in:
Eren Gölge 2021-05-28 13:33:58 +02:00
parent 5f07315722
commit 891631ab47
1 changed file with 151 additions and 124 deletions


@@ -65,12 +65,12 @@ class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)
 
     def __init__(self,
-                 args,
-                 config,
-                 c_logger,
-                 tb_logger,
-                 model=None,
-                 output_path=None):
+                 args: Union[Coqpit, Namespace],
+                 config: Coqpit,
+                 c_logger: ConsoleLogger,
+                 tb_logger: TensorboardLogger,
+                 model: nn.Module = None,
+                 output_path: str = None) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
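A note on the annotated signature: under PEP 484, parameters that default to None are conventionally marked Optional rather than given a bare type. A hedged sketch of the stricter spelling (ConsoleLogger and TensorboardLogger as imported by this module; body elided):

    from argparse import Namespace
    from typing import Optional, Union

    from coqpit import Coqpit
    from torch import nn

    def __init__(self,
                 args: Union[Coqpit, Namespace],
                 config: Coqpit,
                 c_logger: ConsoleLogger,
                 tb_logger: TensorboardLogger,
                 model: Optional[nn.Module] = None,
                 output_path: Optional[str] = None) -> None:
        ...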
@@ -88,43 +88,52 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None
 
+        log_file = os.path.join(self.output_path,
+                                f"trainer_{args.rank}_log.txt")
+        self._setup_logger_config(log_file)
+
         # model, audio processor, datasets, loss
         # init audio processor
-        self.ap = AudioProcessor(**config.audio.to_dict())
+        self.ap = AudioProcessor(**self.config.audio.to_dict())
         # init character processor
-        self.model_characters = self.init_character_processor()
+        self.model_characters = self.get_character_processor(self.config)
         # load dataset samples
-        self.data_train, self.data_eval = load_meta_data(config.datasets)
+        self.data_train, self.data_eval = load_meta_data(self.config.datasets)
         # default speaker manager
-        self.speaker_manager = self.init_speaker_manager()
+        self.speaker_manager = self.get_speaker_manager(
+            self.config, args.restore_path, self.config.output_path, self.data_train)
         # init TTS model
         if model is not None:
             self.model = model
         else:
-            self.model = self.init_model()
+            self.model = self.get_model(
+                len(self.model_characters), self.speaker_manager.num_speakers,
+                self.config, self.speaker_manager.x_vector_dim
+                if self.speaker_manager.x_vectors else None)
         # setup criterion
-        self.criterion = self.init_criterion()
+        self.criterion = self.get_criterion(self.config)
+        if self.use_cuda:
+            self.model.cuda()
+            self.criterion.cuda()
         # DISTRIBUTED
         if self.num_gpus > 1:
             init_distributed(args.rank, self.num_gpus, args.group_id,
-                             config.distributed["backend"],
-                             config.distributed["url"])
+                             self.config.distributed["backend"],
+                             self.config.distributed["url"])
         # scalers for mixed precision training
         self.scaler = torch.cuda.amp.GradScaler(
-        ) if config.mixed_precision else None
+        ) if self.config.mixed_precision and self.use_cuda else None
         # setup optimizer
-        self.optimizer = self.init_optimizer(self.model)
+        self.optimizer = self.get_optimizer(self.model, self.config)
+        # setup scheduler
+        self.scheduler = self.get_scheduler(self.config, self.optimizer)
 
         if self.args.restore_path:
             self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
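Two behavioural tweaks ride along with the annotations above: the model and criterion are now moved to CUDA inside __init__, and the AMP GradScaler is created only when both mixed precision and CUDA are enabled, since gradient scaling serves CUDA autocast and does nothing useful on CPU. The gating logic in isolation, as a sketch:

    import torch

    use_cuda = torch.cuda.is_available()
    mixed_precision = True  # stand-in for config.mixed_precision
    # a scaler only helps CUDA autocast; skip it on CPU-only runs
    scaler = torch.cuda.amp.GradScaler() if mixed_precision and use_cuda else None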
@@ -144,64 +153,66 @@ class TrainerTTS:
         logging.info("\n > Model has {} parameters".format(num_params),
                      flush=True)
 
-    def init_model(self):
-        model = setup_model(
-            len(self.model_characters),
-            self.speaker_manager.num_speakers,
-            self.config,
-            self.speaker_manager.x_vector_dim
-            if self.speaker_manager.x_vectors else None,
-        )
+    @staticmethod
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
+                  x_vector_dim: int) -> nn.Module:
+        model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model
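The pattern here, repeated for the helpers below, is to lift hidden instance state into explicit parameters so each builder becomes a static, independently testable function. A hypothetical call site mirroring the new signature, with the same expressions the constructor uses:

    model = TrainerTTS.get_model(
        num_chars=len(model_characters),
        num_speakers=speaker_manager.num_speakers,
        config=config,
        x_vector_dim=speaker_manager.x_vector_dim
        if speaker_manager.x_vectors else None)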
-    def init_optimizer(self, model):
-        optimizer_name = self.config.optimizer
-        optimizer_params = self.config.optimizer_params
+    @staticmethod
+    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
+        optimizer_name = config.optimizer
+        optimizer_params = config.optimizer_params
         if optimizer_name.lower() == "radam":
             module = importlib.import_module("TTS.utils.radam")
             optimizer = getattr(module, "RAdam")
         else:
             optimizer = getattr(torch.optim, optimizer_name)
-        return optimizer(model.parameters(),
-                         lr=self.config.lr,
-                         **optimizer_params)
+        return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
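get_optimizer resolves the optimizer class by name at runtime: the project's own RAdam module, else an attribute lookup on torch.optim. A minimal self-contained sketch of the same pattern (build_optimizer is a hypothetical name, not part of this commit):

    import importlib

    import torch
    from torch import nn

    def build_optimizer(model: nn.Module, name: str, lr: float, **params):
        # resolve the class by name: project RAdam, else torch.optim.<name>
        if name.lower() == "radam":
            optimizer_class = getattr(
                importlib.import_module("TTS.utils.radam"), "RAdam")
        else:
            optimizer_class = getattr(torch.optim, name)
        return optimizer_class(model.parameters(), lr=lr, **params)

    # e.g. build_optimizer(model, "Adam", lr=1e-3, betas=(0.9, 0.998))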
-    def init_character_processor(self):
+    @staticmethod
+    def get_character_processor(config: Coqpit) -> str:
         # setup custom characters if set in config file.
         # TODO: implement CharacterProcessor
-        if self.config.characters is not None:
-            symbols, phonemes = make_symbols(
-                **self.config.characters.to_dict())
+        if config.characters is not None:
+            symbols, phonemes = make_symbols(**config.characters.to_dict())
         else:
-            from TTS.tts.utils.text.symbols import symbols, phonemes
-        model_characters = phonemes if self.config.use_phonemes else symbols
+            from TTS.tts.utils.text.symbols import phonemes, symbols
+        model_characters = phonemes if config.use_phonemes else symbols
         return model_characters
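One caveat: the returned value is the symbol inventory the model trains on (a sequence of characters), so the -> str annotation is arguably too narrow; assuming make_symbols yields lists, a closer spelling would be:

    # hedged sketch: a more precise return annotation for the same body
    @staticmethod
    def get_character_processor(config: Coqpit) -> List[str]:
        ...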
-    def init_speaker_manager(self, restore_path: str = "", out_path: str = ""):
+    @staticmethod
+    def get_speaker_manager(config: Coqpit,
+                            restore_path: str = "",
+                            out_path: str = "",
+                            data_train: List = []) -> SpeakerManager:
         speaker_manager = SpeakerManager()
-        if restore_path:
-            speakers_file = os.path.join(os.path.dirname(restore_path),
-                                         "speaker.json")
-            if not os.path.exists(speakers_file):
-                logging.info(
-                    "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
-                )
-                speakers_file = self.config.external_speaker_embedding_file
-            if self.config.use_external_speaker_embedding_file:
-                speaker_manager.load_x_vectors_file(speakers_file)
-            else:
-                self.speaker_manage.load_speaker_mapping(speakers_file)
-        elif self.config.use_external_speaker_embedding_file and self.config.external_speaker_embedding_file:
-            speaker_manager.load_x_vectors_file(
-                self.config.external_speaker_embedding_file)
-        else:
-            speaker_manager.parse_speakers_from_items(self.data_train)
-            file_path = os.path.join(out_path, "speakers.json")
-            speaker_manager.save_ids_file(file_path)
+        if config.use_speaker_embedding:
+            if restore_path:
+                speakers_file = os.path.join(os.path.dirname(restore_path),
+                                             "speaker.json")
+                if not os.path.exists(speakers_file):
+                    print(
+                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
+                    )
+                    speakers_file = config.external_speaker_embedding_file
+                if config.use_external_speaker_embedding_file:
+                    speaker_manager.load_x_vectors_file(speakers_file)
+                else:
+                    speaker_manager.load_ids_file(speakers_file)
+            elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
+                speaker_manager.load_x_vectors_file(
+                    config.external_speaker_embedding_file)
+            else:
+                speaker_manager.parse_speakers_from_items(data_train)
+                file_path = os.path.join(out_path, "speakers.json")
+                speaker_manager.save_ids_file(file_path)
         return speaker_manager
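One hazard worth flagging in the new signature: data_train: List = [] is a mutable default shared across calls (pylint's dangerous-default-value). A sketch of the conventional safe spelling, assuming Optional is imported from typing and the same behaviour is wanted:

    @staticmethod
    def get_speaker_manager(config: Coqpit,
                            restore_path: str = "",
                            out_path: str = "",
                            data_train: Optional[List] = None) -> SpeakerManager:
        # fresh list per call instead of one shared default object
        data_train = data_train if data_train is not None else []
        ...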
-    def init_scheduler(self, config, optimizer):
+    @staticmethod
+    def get_scheduler(config: Coqpit,
+                      optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
@@ -213,17 +224,20 @@ class TrainerTTS:
             scheduler = getattr(torch.optim, lr_scheduler)
         return scheduler(optimizer, **lr_scheduler_params)
 
-    def init_criterion(self):
-        return setup_loss(self.config)
+    @staticmethod
+    def get_criterion(config: Coqpit) -> nn.Module:
+        return setup_loss(config)
-    def restore_model(self,
-                      config,
-                      restore_path,
-                      model,
-                      optimizer,
-                      scaler=None):
-        logging.info(f" > Restoring from {os.path.basename(restore_path)}...")
-        checkpoint = torch.load(restore_path, map_location="cpu")
+    def restore_model(
+        self,
+        config: Coqpit,
+        restore_path: str,
+        model: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        scaler: torch.cuda.amp.GradScaler = None
+    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
+        print(" > Restoring from %s ..." % os.path.basename(restore_path))
+        checkpoint = torch.load(restore_path)
         try:
             logging.info(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
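Note that the new call drops map_location="cpu" from torch.load, so restoring a GPU-saved checkpoint now assumes a compatible device is present. A sketch of the device-agnostic variant the old code used:

    # load onto CPU first, regardless of where the checkpoint was saved;
    # the trainer moves the model to CUDA separately when use_cuda is set
    checkpoint = torch.load(restore_path, map_location="cpu")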
@@ -242,20 +256,20 @@ class TrainerTTS:
             for group in optimizer.param_groups:
                 group["lr"] = self.config.lr
-        logging.info(" > Model restored from step %d" % checkpoint["step"],
-                     flush=True)
+        print(" > Model restored from step %d" % checkpoint["step"], )
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step
-    def _setup_loader(self, r, ap, is_eval, data_items, verbose,
-                      speaker_mapping):
+    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
+                    data_items: List, verbose: bool,
+                    speaker_mapping: Union[Dict, List]) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
             dataset = TTSDataset(
                 outputs_per_step=r,
                 text_cleaner=self.config.text_cleaner,
-                compute_linear_spec= 'tacotron' == self.config.model.lower(),
+                compute_linear_spec=self.config.model.lower() == "tacotron",
                 meta_data=data_items,
                 ap=ap,
                 tp=self.config.characters,
@@ -296,17 +310,19 @@ class TrainerTTS:
             )
         return loader
 
-    def setup_train_dataloader(self, r, ap, data_items, verbose,
-                               speaker_mapping):
-        return self._setup_loader(r, ap, False, data_items, verbose,
-                                  speaker_mapping)
+    def get_train_dataloader(self, r: int, ap: AudioProcessor,
+                             data_items: List, verbose: bool,
+                             speaker_mapping: Union[List, Dict]) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose,
+                                speaker_mapping)
 
-    def setup_eval_dataloder(self, r, ap, data_items, verbose,
-                             speaker_mapping):
-        return self._setup_loader(r, ap, True, data_items, verbose,
-                                  speaker_mapping)
+    def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
+                           verbose: bool,
+                           speaker_mapping: Union[List, Dict]) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose,
+                                speaker_mapping)
 
-    def format_batch(self, batch):
+    def format_batch(self, batch: List) -> Dict:
         # setup input batch
         text_input = batch[0]
         text_lengths = batch[1]
@@ -401,7 +417,8 @@ class TrainerTTS:
             "item_idx": item_idx
         }
 
-    def train_step(self, batch, batch_n_steps, step, loader_start_time):
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
+                   loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
 
         step_start_time = time.time()
@@ -515,7 +532,7 @@ class TrainerTTS:
         self.on_train_step_end()
         return outputs, loss_dict
 
-    def train_epoch(self):
+    def train_epoch(self) -> None:
         self.model.train()
         epoch_start_time = time.time()
         if self.use_cuda:
@@ -541,7 +558,7 @@ class TrainerTTS:
                 self.tb_logger.tb_model_weights(self.model,
                                                 self.total_steps_done)
 
-    def eval_step(self, batch, step):
+    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
         with torch.no_grad():
             step_start_time = time.time()
@@ -572,17 +589,11 @@ class TrainerTTS:
                                          self.keep_avg_eval.avg_values)
         return outputs, loss_dict
 
-    def eval_epoch(self):
+    def eval_epoch(self) -> None:
         self.model.eval()
-        if self.use_cuda:
-            batch_num_steps = int(
-                len(self.train_loader.dataset) /
-                (self.config.batch_size * self.num_gpus))
-        else:
-            batch_num_steps = int(
-                len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_eval_start()
         loader_start_time = time.time()
+        batch = None
         for cur_step, batch in enumerate(self.eval_loader):
             # format data
             batch = self.format_batch(batch)
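Besides deleting the unused batch_num_steps computation, the hunk seeds batch = None before the loop, presumably so that code touching batch after an empty eval loader fails soft instead of raising NameError. The shape of the guard, in isolation:

    batch = None
    for cur_step, batch in enumerate(eval_loader):
        ...  # per-step evaluation
    if batch is not None:
        ...  # the last batch is safe to inspect here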
@@ -597,8 +608,8 @@ class TrainerTTS:
                     {"EvalAudio": eval_audios},
                     self.ap.sample_rate)
 
-    def test_run(self, ):
-        logging.info(" | > Synthesizing test sentences.")
+    def test_run(self, ) -> None:
+        print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
@@ -618,9 +629,11 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()
-            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path,
+                                     str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path,
+                                     "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(
@@ -629,16 +642,17 @@ class TrainerTTS:
                 alignment, output_fig=False)
 
         self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
                                       self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
 
-    def _get_cond_inputs(self):
+    def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0]
-        ) if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding else None
+        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
+            self.speaker_manager.speaker_ids[0])
+                    if self.config.use_external_speaker_embedding_file
+                    and self.config.use_speaker_embedding else None)
         # setup style_mel
         if self.config.has('gst_style_input'):
             style_wav = self.config.gst_style_input
@@ -647,35 +661,40 @@ class TrainerTTS:
         if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
             # initialize GST with a zero dict.
             style_wav = {}
-            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
+            print(
+                "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
+            )
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
 
-        cond_inputs = {'speaker_id': speaker_id, 'style_wav': style_wav, 'x_vector': x_vector}
+        cond_inputs = {
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "x_vector": x_vector
+        }
         return cond_inputs
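The zero-dict fallback hands the GST layer a deterministic all-zero style when no reference wav is given: one zero weight per style token, keyed by the token index as a string. The loop above is equivalent to this comprehension:

    style_wav = {
        str(i): 0
        for i in range(self.config.gst["gst_num_style_tokens"])
    }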
-    def fit(self):
+    def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            logging.info(" > Restoring best loss from "
-                         f"{os.path.basename(self.args.best_path)} ...")
+            print(" > Restoring best loss from "
+                  f"{os.path.basename(self.args.best_path)} ...")
             self.best_loss = torch.load(self.args.best_path,
                                         map_location="cpu")["model_loss"]
-            logging.info(
-                f" > Starting with loaded last best loss {self.best_loss}.")
+            print(f" > Starting with loaded last best loss {self.best_loss}.")
 
         # define data loaders
-        self.train_loader = self.setup_train_dataloader(
+        self.train_loader = self.get_train_dataloader(
             self.config.r,
             self.ap,
             self.data_train,
             verbose=True,
             speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = self.setup_eval_dataloder(
+        self.eval_loader = (self.get_eval_dataloder(
             self.config.r,
             self.ap,
             self.data_train,
             verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids
-        ) if self.config.run_eval else None
+            speaker_mapping=self.speaker_manager.speaker_ids)
+                            if self.config.run_eval else None)
 
         self.total_steps_done = self.restore_step
@@ -697,10 +716,10 @@ class TrainerTTS:
                 self.save_best_model()
             self.on_epoch_end()
 
-    def save_best_model(self):
+    def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval['avg_loss']
-            if self.keep_avg_eval else self.keep_avg_train['avg_loss'],
+            self.keep_avg_eval["avg_loss"]
+            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -715,8 +734,16 @@ class TrainerTTS:
             if self.config.mixed_precision else None,
         )
 
-    def on_epoch_start(self):
-        if hasattr(self.model, 'on_epoch_start'):
+    @staticmethod
+    def _setup_logger_config(log_file: str) -> None:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="",
+            handlers=[logging.FileHandler(log_file),
+                      logging.StreamHandler()])
+
+    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_start"):
             self.model.on_epoch_start(self)
 
         if hasattr(self.criterion, "on_epoch_start"):
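With _setup_logger_config wired into __init__, each logging call is fanned out to both the per-run log file and the console; format="" keeps records unprefixed so the existing " > ..." message style is preserved. A minimal usage sketch with an illustrative file name:

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="",
        handlers=[logging.FileHandler("trainer_0_log.txt"),
                  logging.StreamHandler()])
    logging.info(" > EPOCH: 0/1000")  # written to the file and to stderr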
@@ -725,8 +752,8 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_epoch_start"):
             self.optimizer.on_epoch_start(self)
 
-    def on_epoch_end(self):
-        if hasattr(self.model, "on_epoch_start"):
+    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_end"):
             self.model.on_epoch_end(self)
 
         if hasattr(self.criterion, "on_epoch_end"):
@@ -735,8 +762,8 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_epoch_end"):
             self.optimizer.on_epoch_end(self)
 
-    def on_train_step_start(self):
-        if hasattr(self.model, "on_epoch_start"):
+    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_train_step_start"):
             self.model.on_train_step_start(self)
 
         if hasattr(self.criterion, "on_train_step_start"):
@@ -745,7 +772,7 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_train_step_start"):
             self.optimizer.on_train_step_start(self)
 
-    def on_train_step_end(self):
+    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_train_step_end"):
             self.model.on_train_step_end(self)
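These callbacks share one dispatch pattern, and the hunks above also fix two copy-paste guards that tested "on_epoch_start" before calling a different hook. A generic sketch of the pattern (dispatch_event is a hypothetical helper, not part of this commit):

    def dispatch_event(trainer, event: str) -> None:
        # forward an event to every component that implements the hook
        for component in (trainer.model, trainer.criterion, trainer.optimizer):
            hook = getattr(component, event, None)
            if callable(hook):
                hook(trainer)

    # e.g. dispatch_event(trainer, "on_epoch_start")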