trainer-API updates

Eren Gölge 2021-05-20 18:23:53 +02:00
parent 42554cc711
commit 8def3c87af
13 changed files with 62 additions and 31 deletions

View File

@@ -229,7 +229,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:

View File

@@ -270,7 +270,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:

View File

@@ -256,7 +256,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:

View File

@@ -327,7 +327,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                 "step_time": step_time,
             }
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:

View File

@@ -265,7 +265,7 @@ def train(
         if global_step % 10 == 0:
             iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:

View File

@@ -181,7 +181,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:

View File

@@ -163,7 +163,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": cur_lr, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
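The same per-step logging pattern is updated in every trainer above. A minimal, self-contained sketch of that pattern follows; DummyLogger, plot_step, and the loss/grad values are illustrative stand-ins, not objects from the repo.

import time

class DummyLogger:
    def tb_train_step_stats(self, step, stats):
        # a real logger forwards these to TensorBoard; here we just print them
        print(step, stats)

tb_logger = DummyLogger()
plot_step = 10
current_lr = 1e-4
for global_step in range(1, 31):
    step_start = time.time()
    loss_dict = {"loss": 1.0 / global_step}  # stand-in for the model losses
    grad_norm = 0.5                          # stand-in for the clipped gradient norm
    step_time = time.time() - step_start
    if global_step % plot_step == 0:
        iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
        iter_stats.update(loss_dict)         # merge losses into the same stats dict
        tb_logger.tb_train_step_stats(global_step, iter_stats)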

View File

@@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
         datasets (List[BaseDatasetConfig]):
             List of datasets used for training. If multiple datasets are provided, they are merged and used together
             for training.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            Defaults to ``.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`.
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to ``.
+        lr_scheduler_params (dict):
+            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
+        test_sentences (List[str]):
+            List of sentences to be used at testing. Defaults to `[]`.
     """

     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -158,3 +170,11 @@ class BaseTTSConfig(BaseTrainingConfig):
     add_blank: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+    # optimizer
+    optimizer: str = MISSING
+    optimizer_params: dict = MISSING
+    # scheduler
+    lr_scheduler: str = ""
+    lr_scheduler_params: dict = field(default_factory=lambda: {})
+    # testing
+    test_sentences: List[str] = field(default_factory=lambda: [])
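The new optimizer fields follow a "name plus kwargs" pattern. A minimal sketch of how such a pair can be resolved into a concrete optimizer, assuming resolution against torch.optim; the optimizer name and kwargs below are example values, not defaults taken from the repo:

import torch

optimizer_name = "Adam"  # example name; the config default above is MISSING
optimizer_params = {"betas": [0.8, 0.99], "weight_decay": 0.0}

model = torch.nn.Linear(10, 10)
optimizer_class = getattr(torch.optim, optimizer_name)
optimizer = optimizer_class(model.parameters(), lr=1e-4, **optimizer_params)
print(optimizer)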

View File

@@ -78,10 +78,16 @@ class TacotronConfig(BaseTTSConfig):
             enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
         external_speaker_embedding_file (str):
             Path to the file including pre-computed speaker embeddings. Defaults to None.
-        noam_schedule (bool):
-            enable / disable the use of Noam LR scheduler. Defaults to False.
-        warmup_steps (int):
-            Number of warm-up steps for the Noam scheduler. Defaults to 4000.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            Defaults to `RAdam`.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`.
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to `NoamLR`.
+        lr_scheduler_params (dict):
+            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
         lr (float):
             Initial learning rate. Defaults to `1e-4`.
         wd (float):
@@ -152,10 +158,11 @@ class TacotronConfig(BaseTTSConfig):

     external_speaker_embedding_file: str = False
     # optimizer parameters
-    noam_schedule: bool = False
-    warmup_steps: int = 4000
+    optimizer: str = "RAdam"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
     lr: float = 1e-4
-    wd: float = 1e-6
     grad_clip: float = 5.0
     seq_len_norm: bool = False
     loss_masking: bool = True
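The old noam_schedule/warmup_steps flags are replaced by lr_scheduler="NoamLR" with lr_scheduler_params={"warmup_steps": 4000}. A simplified re-implementation of such a warmup schedule, using torch's LambdaLR rather than the class from TTS.utils.training, purely for illustration:

import torch

def noam_lambda(step, warmup_steps=4000):
    # Noam factor: linear warmup followed by inverse-sqrt decay, peaking near 1.0 at warmup_steps
    step = max(step, 1)
    return warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: noam_lambda(s, warmup_steps=4000))
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())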

View File

@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union
+from typing import Union, List, Any

 import numpy as np
 import torch
@@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):


 def get_speakers(items):
-    """Returns a sorted, unique list of speakers in a given dataset."""
-    speakers = {e[2] for e in items}
-    return sorted(speakers)


 def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:
     Args:
         x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-            TTS model. Defaults to "".
+        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+            TTS models. Defaults to "".
         encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
         encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "".
     """

     def __init__(
         self,
+        data_items: List[List[Any]] = None,
         x_vectors_file_path: str = "",
         speaker_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
     ):
-        self.x_vectors = None
-        self.speaker_ids = None
-        self.clip_ids = None
+        self.data_items = []
+        self.x_vectors = []
+        self.speaker_ids = []
+        self.clip_ids = []
         self.speaker_encoder = None
         self.speaker_encoder_ap = None
+        if data_items:
+            self.speaker_ids = self.parse_speakers()
         if x_vectors_file_path:
             self.load_x_vectors_file(x_vectors_file_path)
@@ -169,10 +172,10 @@ class SpeakerManager:
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

     def parser_speakers_from_items(self, items: list):
-        speaker_ids = sorted({item[2] for item in items})
-        self.speaker_ids = speaker_ids
-        num_speakers = len(speaker_ids)
-        return speaker_ids, num_speakers
+        speakers = sorted({item[2] for item in items})
+        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+        num_speakers = len(self.speaker_ids)
+        return self.speaker_ids, num_speakers

     def save_ids_file(self, file_path: str):
         self._save_json(file_path, self.speaker_ids)
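With this change the manager stores a speaker-name-to-id dict instead of a plain list. A standalone sketch of the new mapping logic, run on made-up dataset items of the form [text, wav_path, speaker_name]:

items = [
    ["hello there", "wavs/a.wav", "spk_b"],
    ["general kenobi", "wavs/b.wav", "spk_a"],
    ["fine weather", "wavs/c.wav", "spk_a"],
]
speakers = sorted({item[2] for item in items})            # unique, sorted speaker names
speaker_ids = {name: i for i, name in enumerate(speakers)}  # name -> integer id
num_speakers = len(speaker_ids)
print(speaker_ids, num_speakers)  # {'spk_a': 0, 'spk_b': 1} 2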

View File

@@ -65,7 +65,7 @@ def basic_cleaners(text):

 def transliteration_cleaners(text):
     """Pipeline for non-English text that transliterates to ASCII."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

 def english_cleaners(text):
     """Pipeline for English text, including number and abbreviation expansion."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = expand_time_english(text)
     text = expand_numbers(text)
@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:

 def phoneme_cleaners(text):
     """Pipeline for phonemes mode, including number and abbreviation expansion."""
     text = expand_numbers(text)
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = expand_abbreviations(text)
     text = replace_symbols(text)
     text = remove_aux_symbols(text)
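Disabling convert_to_ascii means accented characters now survive the pipelines. A simplified sketch with stand-in helpers (these mirror the behaviour of lowercase and collapse_whitespace but are not the repo's implementations):

import re

def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(r"\s+", " ", text).strip()

def transliteration_cleaners_sketch(text):
    # text = convert_to_ascii(text)  # no longer applied, so accented characters are kept
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text

print(transliteration_cleaners_sketch("Gölge   söyledi"))  # -> "gölge söyledi"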

View File

@@ -39,7 +39,7 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()

-    def tb_train_iter_stats(self, step, stats):
+    def tb_train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

     def tb_train_epoch_stats(self, step, stats):
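This rename is what the trainer hunks above call. A minimal sketch of what the renamed method effectively does, assuming torch.utils.tensorboard is available; MiniLogger is a stand-in for illustration, not the repo's TensorboardLogger class:

from torch.utils.tensorboard import SummaryWriter

class MiniLogger:
    def __init__(self, log_dir, model_name):
        self.writer = SummaryWriter(log_dir)
        self.model_name = model_name

    def tb_train_step_stats(self, step, stats):
        # write every entry of the stats dict as a scalar under a TrainIterStats prefix
        for key, value in stats.items():
            self.writer.add_scalar(f"{self.model_name}_TrainIterStats/{key}", value, step)

logger = MiniLogger("tb_example_logs", "tacotron2")
logger.tb_train_step_stats(100, {"lr": 1e-4, "loss": 0.42})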

View File

@@ -21,6 +21,7 @@ config = MelganConfig(
     print_step=1,
     discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     print_eval=True,
+    discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     data_path="tests/data/ljspeech",
     output_path=output_path,
 )