Implement BaseTTSE2E

Eren Gölge 2022-04-04 09:43:15 +02:00 committed by Eren Gölge
parent b16613c5ad
commit c125024da0
1 changed file with 25 additions and 2 deletions


@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
@@ -108,7 +109,9 @@ class BaseTTS(BaseTrainerModel):
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
     def get_aux_input(self, **kwargs) -> Dict:
-        """Prepare and return `aux_input` used by `forward()`"""
+        """Prepare and return `aux_input` used by `forward()`
+
+        If not overridden, this function returns a dictionary with None values"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
     def get_aux_input_from_test_setences(self, sentence_info):
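Note: the default get_aux_input() above only returns None placeholders for every conditioning input. A minimal sketch of how a subclass might override it to fill in a speaker id; the MySpeakerTTS class below is illustrative only and is not part of this commit:

from typing import Dict

from TTS.tts.models.base_tts import BaseTTS


class MySpeakerTTS(BaseTTS):
    # Illustrative override: start from the all-None defaults returned by
    # BaseTTS.get_aux_input() and fill in the inputs this model actually uses.
    def get_aux_input(self, speaker_id: int = None, **kwargs) -> Dict:
        aux_input = super().get_aux_input(**kwargs)
        aux_input["speaker_id"] = speaker_id
        return aux_input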
@@ -323,7 +326,9 @@ class BaseTTS(BaseTrainerModel):
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                d_vector_mapping=d_vector_mapping
+                if get_from_config_or_model_args(config, "use_d_vector_file")
+                else None,
                 tokenizer=self.tokenizer,
                 start_by_longest=config.start_by_longest,
                 language_id_mapping=language_id_mapping,
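Note: the rewritten d_vector_mapping argument no longer reads config.use_d_vector_file directly; it goes through get_from_config_or_model_args, which also looks for the flag under config.model_args for models (e.g. VITS) that keep their arguments nested there. A rough, simplified stand-in for that lookup, for illustration only; the exact precedence in TTS.config may differ:

from coqpit import Coqpit


def lookup_arg(config: Coqpit, arg_name: str):
    # Illustrative stand-in for get_from_config_or_model_args(): check the
    # nested model_args coqpit first, then fall back to the top-level field.
    if hasattr(config, "model_args") and hasattr(config.model_args, arg_name):
        return getattr(config.model_args, arg_name)
    return getattr(config, arg_name)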
@@ -428,3 +433,21 @@ class BaseTTS(BaseTrainerModel):
             trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
             print(f" > `language_ids.json` is saved to {output_path}.")
             print(" > `language_ids_file` is updated in the config.json.")
+
+
+class BaseTTSE2E(BaseTTS):
+    def _set_model_args(self, config: Coqpit):
+        self.config = config
+        if "Config" in config.__class__.__name__:
+            num_chars = (
+                self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
+            )
+            self.config.model_args.num_chars = num_chars
+            self.config.num_chars = num_chars
+            self.args = config.model_args
+            self.args.num_chars = num_chars
+        elif "Args" in config.__class__.__name__:
+            self.args = config
+            self.args.num_chars = self.args.num_chars
+        else:
+            raise ValueError("config must be either a *Config or *Args")
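Note: the new BaseTTSE2E._set_model_args() accepts either a full *Config coqpit (with the model arguments nested under model_args) or a bare *Args coqpit, dispatching on the class name, and takes num_chars from the tokenizer once one is set, otherwise from model_args. A small self-contained sketch of the two accepted shapes; the DemoE2E* coqpits below are hypothetical and exist only to exercise the naming check:

from dataclasses import dataclass, field

from coqpit import Coqpit


@dataclass
class DemoE2EArgs(Coqpit):
    # Bare "*Args" coqpit: used directly as self.args.
    num_chars: int = 0


@dataclass
class DemoE2EConfig(Coqpit):
    # Full "*Config" coqpit: num_chars is mirrored onto both the top level and
    # model_args, and self.args is set to config.model_args.
    num_chars: int = 0
    model_args: DemoE2EArgs = field(default_factory=DemoE2EArgs)


if __name__ == "__main__":
    # Same class-name check as _set_model_args(): "Config" -> config branch,
    # "Args" -> args branch, anything else -> ValueError.
    for cfg in (DemoE2EConfig(), DemoE2EArgs(), object()):
        name = cfg.__class__.__name__
        branch = "Config" if "Config" in name else "Args" if "Args" in name else "ValueError"
        print(f"{name} -> {branch}")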