mirror of https://github.com/coqui-ai/TTS.git
Implement BaseTTSE2E
This commit is contained in:
parent
b16613c5ad
commit
c125024da0
|
@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
|
||||
from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
|
@ -108,7 +109,9 @@ class BaseTTS(BaseTrainerModel):
|
|||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
"""Prepare and return `aux_input` used by `forward()`
|
||||
|
||||
If not overridden, this function returns a dictionary with None values"""
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
|
||||
|
||||
def get_aux_input_from_test_setences(self, sentence_info):
|
||||
|
@ -323,7 +326,9 @@ class BaseTTS(BaseTrainerModel):
|
|||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||
d_vector_mapping=d_vector_mapping
|
||||
if get_from_config_or_model_args(config, "use_d_vector_file")
|
||||
else None,
|
||||
tokenizer=self.tokenizer,
|
||||
start_by_longest=config.start_by_longest,
|
||||
language_id_mapping=language_id_mapping,
|
||||
|
@ -428,3 +433,21 @@ class BaseTTS(BaseTrainerModel):
|
|||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||
print(" > `language_ids_file` is updated in the config.json.")
|
||||
|
||||
|
||||
class BaseTTSE2E(BaseTTS):
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
self.config = config
|
||||
if "Config" in config.__class__.__name__:
|
||||
num_chars = (
|
||||
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
|
||||
)
|
||||
self.config.model_args.num_chars = num_chars
|
||||
self.config.num_chars = num_chars
|
||||
self.args = config.model_args
|
||||
self.args.num_chars = num_chars
|
||||
elif "Args" in config.__class__.__name__:
|
||||
self.args = config
|
||||
self.args.num_chars = self.args.num_chars
|
||||
else:
|
||||
raise ValueError("config must be either a *Config or *Args")
|
||||
|
|
Loading…
Reference in New Issue