diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index c71872d3..6069eb26 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 
+from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
@@ -108,7 +109,9 @@ class BaseTTS(BaseTrainerModel):
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
     def get_aux_input(self, **kwargs) -> Dict:
-        """Prepare and return `aux_input` used by `forward()`"""
+        """Prepare and return `aux_input` used by `forward()`
+
+        If not overridden, this function returns a dictionary with None values"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
     def get_aux_input_from_test_setences(self, sentence_info):
@@ -323,7 +326,9 @@ class BaseTTS(BaseTrainerModel):
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                d_vector_mapping=d_vector_mapping
+                if get_from_config_or_model_args(config, "use_d_vector_file")
+                else None,
                 tokenizer=self.tokenizer,
                 start_by_longest=config.start_by_longest,
                 language_id_mapping=language_id_mapping,
@@ -428,3 +433,39 @@ class BaseTTS(BaseTrainerModel):
             trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
             print(f" > `language_ids.json` is saved to {output_path}.")
             print(" > `language_ids_file` is updated in the config.json.")
+
+
+class BaseTTSE2E(BaseTTS):
+    """`BaseTTS` variant that keeps `num_chars` in sync between a config and its `model_args`."""
+
+    def _set_model_args(self, config: Coqpit):
+        """Set `self.config` and `self.args` from a `*Config` or `*Args` object.
+
+        For a `*Config`, `num_chars` is read from the tokenizer when one is set, otherwise from
+        `config.model_args`; it is then written to both `config.num_chars` and
+        `config.model_args.num_chars`, and `self.args` is set to `config.model_args`.
+
+        For an `*Args`, the object is stored as `self.args` without modification.
+
+        Args:
+            config (Coqpit): Object whose class name contains "Config" or "Args".
+
+        Raises:
+            ValueError: If the class name contains neither "Config" nor "Args".
+        """
+        self.config = config
+        # The kind of object is detected from its class name.
+        if "Config" in config.__class__.__name__:
+            # A tokenizer, when set, takes precedence over the configured character count.
+            num_chars = (
+                self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
+            )
+            self.config.model_args.num_chars = num_chars
+            self.config.num_chars = num_chars
+            self.args = config.model_args
+            self.args.num_chars = num_chars
+        elif "Args" in config.__class__.__name__:
+            # Model args are used as given; `num_chars` is left untouched.
+            self.args = config
+        else:
+            raise ValueError("config must be either a *Config or *Args")