mirror of https://github.com/coqui-ai/TTS.git
Implement BaseTTSE2E
This commit is contained in:
parent
29216ff907
commit
f1b034c8b0
|
@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
|
||||||
from torch.utils.data.sampler import WeightedRandomSampler
|
from torch.utils.data.sampler import WeightedRandomSampler
|
||||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||||
|
|
||||||
|
from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default
|
||||||
from TTS.model import BaseTrainerModel
|
from TTS.model import BaseTrainerModel
|
||||||
from TTS.tts.datasets.dataset import TTSDataset
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||||
|
@ -107,7 +108,9 @@ class BaseTTS(BaseTrainerModel):
|
||||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||||
|
|
||||||
def get_aux_input(self, **kwargs) -> Dict:
|
def get_aux_input(self, **kwargs) -> Dict:
|
||||||
"""Prepare and return `aux_input` used by `forward()`"""
|
"""Prepare and return `aux_input` used by `forward()`
|
||||||
|
|
||||||
|
If not overridden, this function returns a dictionary with None values"""
|
||||||
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
|
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
|
||||||
|
|
||||||
def get_aux_input_from_test_setences(self, sentence_info):
|
def get_aux_input_from_test_setences(self, sentence_info):
|
||||||
|
@ -318,7 +321,9 @@ class BaseTTS(BaseTrainerModel):
|
||||||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_id_mapping,
|
speaker_id_mapping=speaker_id_mapping,
|
||||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
d_vector_mapping=d_vector_mapping
|
||||||
|
if get_from_config_or_model_args(config, "use_d_vector_file")
|
||||||
|
else None,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
start_by_longest=config.start_by_longest,
|
start_by_longest=config.start_by_longest,
|
||||||
language_id_mapping=language_id_mapping,
|
language_id_mapping=language_id_mapping,
|
||||||
|
@ -423,3 +428,21 @@ class BaseTTS(BaseTrainerModel):
|
||||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||||
print(" > `language_ids_file` is updated in the config.json.")
|
print(" > `language_ids_file` is updated in the config.json.")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTTSE2E(BaseTTS):
|
||||||
|
def _set_model_args(self, config: Coqpit):
|
||||||
|
self.config = config
|
||||||
|
if "Config" in config.__class__.__name__:
|
||||||
|
num_chars = (
|
||||||
|
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
|
||||||
|
)
|
||||||
|
self.config.model_args.num_chars = num_chars
|
||||||
|
self.config.num_chars = num_chars
|
||||||
|
self.args = config.model_args
|
||||||
|
self.args.num_chars = num_chars
|
||||||
|
elif "Args" in config.__class__.__name__:
|
||||||
|
self.args = config
|
||||||
|
self.args.num_chars = self.args.num_chars
|
||||||
|
else:
|
||||||
|
raise ValueError("config must be either a *Config or *Args")
|
||||||
|
|
Loading…
Reference in New Issue