Implement BaseTTSE2E

This commit is contained in:
Eren Gölge 2022-04-04 09:43:15 +02:00
parent 29216ff907
commit f1b034c8b0
1 changed files with 25 additions and 2 deletions

View File

@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default
from TTS.model import BaseTrainerModel from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
@ -107,7 +108,9 @@ class BaseTTS(BaseTrainerModel):
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict: def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`""" """Prepare and return `aux_input` used by `forward()`
If not overridden, this function returns a dictionary with None values"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_setences(self, sentence_info): def get_aux_input_from_test_setences(self, sentence_info):
@ -318,7 +321,9 @@ class BaseTTS(BaseTrainerModel):
use_noise_augment=False if is_eval else config.use_noise_augment, use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose, verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, d_vector_mapping=d_vector_mapping
if get_from_config_or_model_args(config, "use_d_vector_file")
else None,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest, start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping, language_id_mapping=language_id_mapping,
@ -423,3 +428,21 @@ class BaseTTS(BaseTrainerModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.") print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.") print(" > `language_ids_file` is updated in the config.json.")
class BaseTTSE2E(BaseTTS):
def _set_model_args(self, config: Coqpit):
self.config = config
if "Config" in config.__class__.__name__:
num_chars = (
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
)
self.config.model_args.num_chars = num_chars
self.config.num_chars = num_chars
self.args = config.model_args
self.args.num_chars = num_chars
elif "Args" in config.__class__.__name__:
self.args = config
self.args.num_chars = self.args.num_chars
else:
raise ValueError("config must be either a *Config or *Args")