mirror of https://github.com/coqui-ai/TTS.git
trainer-API updates
parent 42554cc711
commit 8def3c87af
@@ -229,7 +229,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
@@ -270,7 +270,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
@@ -256,7 +256,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
@@ -327,7 +327,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                 "step_time": step_time,
             }
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
@@ -265,7 +265,7 @@ def train(
         if global_step % 10 == 0:
             iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
@@ -181,7 +181,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
@@ -163,7 +163,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": cur_lr, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
@@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
         datasets (List[BaseDatasetConfig]):
             List of datasets used for training. If multiple datasets are provided, they are merged and used together
             for training.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            Defaults to ``.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to ``.
+        lr_scheduler_params (dict):
+            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
+        test_sentences (List[str]):
+            List of sentences to be used at testing. Defaults to '[]'
     """

     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -158,3 +170,11 @@ class BaseTTSConfig(BaseTrainingConfig):
     add_blank: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+    # optimizer
+    optimizer: str = MISSING
+    optimizer_params: dict = MISSING
+    # scheduler
+    lr_scheduler: str = ''
+    lr_scheduler_params: dict = field(default_factory=lambda: {})
+    # testing
+    test_sentences: List[str] = field(default_factory=lambda:[])
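The new base-config fields above centralize optimizer and scheduler selection: every model config names them as strings and passes their keyword arguments as dicts. Below is a minimal standalone sketch (a hypothetical dataclass, not the actual TTS config classes) of how a derived config might fill these fields in.

from dataclasses import dataclass, field
from typing import List

@dataclass
class MyModelConfig:
    # the trainer is expected to resolve these names (torch.optim or TTS.utils.training)
    optimizer: str = "Adam"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
    lr_scheduler: str = ""  # empty string means no LR scheduling
    lr_scheduler_params: dict = field(default_factory=dict)
    test_sentences: List[str] = field(default_factory=list)

config = MyModelConfig(lr_scheduler="StepLR", lr_scheduler_params={"step_size": 1000, "gamma": 0.5})
print(config.optimizer, config.lr_scheduler, config.lr_scheduler_params)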
@@ -78,10 +78,16 @@ class TacotronConfig(BaseTTSConfig):
             enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
         external_speaker_embedding_file (str):
             Path to the file including pre-computed speaker embeddings. Defaults to None.
-        noam_schedule (bool):
-            enable / disable the use of Noam LR scheduler. Defaults to False.
-        warmup_steps (int):
-            Number of warm-up steps for the Noam scheduler. Defaults 4000.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            Defaults to `RAdam`.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to `NoamLR`.
+        lr_scheduler_params (dict):
+            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
         lr (float):
             Initial learning rate. Defaults to `1e-4`.
         wd (float):
@@ -152,10 +158,11 @@ class TacotronConfig(BaseTTSConfig):
     external_speaker_embedding_file: str = False

     # optimizer parameters
-    noam_schedule: bool = False
-    warmup_steps: int = 4000
+    optimizer: str = "RAdam"
+    optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
     lr: float = 1e-4
-    wd: float = 1e-6
     grad_clip: float = 5.0
     seq_len_norm: bool = False
     loss_masking: bool = True
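With `noam_schedule`, `warmup_steps`, and `wd` replaced by the generic `optimizer`, `optimizer_params`, `lr_scheduler`, and `lr_scheduler_params` fields, a training script can build both objects from the config alone. A hedged sketch of one way to do that (this is not the repository's actual setup code; it assumes the names exist in `torch.optim`/`torch.optim.lr_scheduler`, whereas custom classes such as `NoamLR` live in `TTS.utils.training` according to the docstring above):

import torch

def setup_optimization(model, config):
    # resolve the optimizer class by name; custom classes (e.g. RAdam, NoamLR)
    # would instead be looked up in TTS.utils.training per the docstrings above
    optim_cls = getattr(torch.optim, config.optimizer)
    optimizer = optim_cls(model.parameters(), lr=config.lr, **config.optimizer_params)
    scheduler = None
    if config.lr_scheduler:
        sched_cls = getattr(torch.optim.lr_scheduler, config.lr_scheduler)
        scheduler = sched_cls(optimizer, **config.lr_scheduler_params)
    return optimizer, scheduler

class DemoConfig:  # stand-in object carrying the fields added in this commit
    optimizer = "Adam"
    optimizer_params = {"betas": (0.9, 0.998), "weight_decay": 1e-6}
    lr = 1e-4
    lr_scheduler = "StepLR"
    lr_scheduler_params = {"step_size": 1000, "gamma": 0.5}

opt, sched = setup_optimization(torch.nn.Linear(4, 4), DemoConfig())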
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union
+from typing import Union, List, Any

 import numpy as np
 import torch
@@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):


 def get_speakers(items):
-    """Returns a sorted, unique list of speakers in a given dataset."""
-    speakers = {e[2] for e in items}
-    return sorted(speakers)


 def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:

     Args:
         x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-            TTS model. Defaults to "".
+        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+            TTS models. Defaults to "".
         encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
         encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "".
     """

     def __init__(
         self,
+        data_items: List[List[Any]] = None,
         x_vectors_file_path: str = "",
         speaker_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
     ):

-        self.x_vectors = None
-        self.speaker_ids = None
-        self.clip_ids = None
+        self.data_items = []
+        self.x_vectors = []
+        self.speaker_ids = []
+        self.clip_ids = []
         self.speaker_encoder = None
         self.speaker_encoder_ap = None

+        if data_items:
+            self.speaker_ids = self.parse_speakers()

         if x_vectors_file_path:
             self.load_x_vectors_file(x_vectors_file_path)
@@ -169,10 +172,10 @@ class SpeakerManager:
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

     def parser_speakers_from_items(self, items: list):
-        speaker_ids = sorted({item[2] for item in items})
-        self.speaker_ids = speaker_ids
-        num_speakers = len(speaker_ids)
-        return speaker_ids, num_speakers
+        speakers = sorted({item[2] for item in items})
+        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+        num_speakers = len(self.speaker_ids)
+        return self.speaker_ids, num_speakers

     def save_ids_file(self, file_path: str):
         self._save_json(file_path, self.speaker_ids)
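After this change, `parser_speakers_from_items` stores a speaker-name-to-index dict rather than a plain sorted list, which is what `save_ids_file` serializes. A standalone sketch of the new behaviour (the `[text, wav_path, speaker_name]` item layout is only inferred from the `item[2]` access and should be treated as an assumption):

items = [
    ["hello world", "wavs/a.wav", "speaker_b"],
    ["good morning", "wavs/b.wav", "speaker_a"],
    ["how are you", "wavs/c.wav", "speaker_a"],
]

# mirrors the new method body: unique speaker names, sorted, then enumerated into ids
speakers = sorted({item[2] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
print(speaker_ids)       # {'speaker_a': 0, 'speaker_b': 1}
print(len(speaker_ids))  # 2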
@@ -65,7 +65,7 @@ def basic_cleaners(text):

 def transliteration_cleaners(text):
     """Pipeline for non-English text that transliterates to ASCII."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

 def english_cleaners(text):
     """Pipeline for English text, including number and abbreviation expansion."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = expand_time_english(text)
     text = expand_numbers(text)
@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
 def phoneme_cleaners(text):
     """Pipeline for phonemes mode, including number and abbreviation expansion."""
     text = expand_numbers(text)
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = expand_abbreviations(text)
     text = replace_symbols(text)
     text = remove_aux_symbols(text)
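All three cleaner pipelines above drop the `convert_to_ascii` step, so accented and other non-ASCII characters now survive cleaning. A standalone illustration using stdlib-only stand-ins for the helpers (the real `lowercase`/`collapse_whitespace` live in the cleaners module):

import re

def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(r"\s+", " ", text)

print(collapse_whitespace(lowercase("Crème  Brûlée")))
# -> "crème brûlée"; with convert_to_ascii enabled the text would first have been
#    transliterated to roughly "creme brulee"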
@@ -39,7 +39,7 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()

-    def tb_train_iter_stats(self, step, stats):
+    def tb_train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

     def tb_train_epoch_stats(self, step, stats):
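The renamed `tb_train_step_stats` keeps the same behaviour and TensorBoard tag; only the method name changes to match the per-step call sites in the training scripts above. A self-contained sketch of what the call does (printing instead of writing to a real `SummaryWriter`):

class MiniLogger:
    """Stand-in for TensorboardLogger, only to show the renamed entry point."""

    def __init__(self, model_name):
        self.model_name = model_name

    def dict_to_tb_scalar(self, scope_name, stats, step):
        for key, value in stats.items():
            # a real logger would call SummaryWriter.add_scalar here
            print(f"{scope_name}/{key} = {value} (step {step})")

    def tb_train_step_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

logger = MiniLogger("Tacotron2")
logger.tb_train_step_stats(100, {"lr": 1e-4, "step_time": 0.42, "loss": 1.23})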
@@ -21,6 +21,7 @@ config = MelganConfig(
     print_step=1,
     discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     print_eval=True,
+    discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     data_path="tests/data/ljspeech",
     output_path=output_path,
 )