mirror of https://github.com/coqui-ai/TTS.git
Update Vits for the new model API
This commit is contained in:
parent
5176ae9e53
commit
001da8afc8
|
@ -278,7 +278,12 @@ class Vits(BaseTTS):
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None
|
self,
|
||||||
|
config: Coqpit,
|
||||||
|
ap: "AudioProcessor" = None,
|
||||||
|
tokenizer: "TTSTokenizer" = None,
|
||||||
|
speaker_manager: SpeakerManager = None,
|
||||||
|
language_manager: LanguageManager = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||||
|
@ -287,8 +292,6 @@ class Vits(BaseTTS):
|
||||||
self.speaker_manager = speaker_manager
|
self.speaker_manager = speaker_manager
|
||||||
self.language_manager = language_manager
|
self.language_manager = language_manager
|
||||||
|
|
||||||
self.args = args
|
|
||||||
|
|
||||||
self.init_multispeaker(config)
|
self.init_multispeaker(config)
|
||||||
self.init_multilingual(config)
|
self.init_multilingual(config)
|
||||||
|
|
||||||
|
@ -309,6 +312,7 @@ class Vits(BaseTTS):
|
||||||
self.args.num_layers_text_encoder,
|
self.args.num_layers_text_encoder,
|
||||||
self.args.kernel_size_text_encoder,
|
self.args.kernel_size_text_encoder,
|
||||||
self.args.dropout_p_text_encoder,
|
self.args.dropout_p_text_encoder,
|
||||||
|
language_emb_dim=self.embedded_language_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.posterior_encoder = PosteriorEncoder(
|
self.posterior_encoder = PosteriorEncoder(
|
||||||
|
@ -884,7 +888,7 @@ class Vits(BaseTTS):
|
||||||
return self._log(self.ap, batch, outputs, "eval")
|
return self._log(self.ap, batch, outputs, "eval")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_run(self) -> Tuple[Dict, Dict]:
|
def test_run(self, assets) -> Tuple[Dict, Dict]:
|
||||||
"""Generic test run for `tts` models used by `Trainer`.
|
"""Generic test run for `tts` models used by `Trainer`.
|
||||||
|
|
||||||
You can override this for a different behaviour.
|
You can override this for a different behaviour.
|
||||||
|
@ -904,7 +908,7 @@ class Vits(BaseTTS):
|
||||||
aux_inputs["text"],
|
aux_inputs["text"],
|
||||||
self.config,
|
self.config,
|
||||||
"cuda" in str(next(self.parameters()).device),
|
"cuda" in str(next(self.parameters()).device),
|
||||||
ap,
|
self.ap,
|
||||||
speaker_id=aux_inputs["speaker_id"],
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
d_vector=aux_inputs["d_vector"],
|
d_vector=aux_inputs["d_vector"],
|
||||||
style_wav=aux_inputs["style_wav"],
|
style_wav=aux_inputs["style_wav"],
|
||||||
|
@ -1007,7 +1011,8 @@ class Vits(BaseTTS):
|
||||||
ap = AudioProcessor.init_from_config(config)
|
ap = AudioProcessor.init_from_config(config)
|
||||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||||
return Vits(new_config, ap, tokenizer, speaker_manager)
|
language_manager = LanguageManager.init_from_config(config)
|
||||||
|
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
|
||||||
|
|
||||||
|
|
||||||
class VitsCharacters(BaseCharacters):
|
class VitsCharacters(BaseCharacters):
|
||||||
|
|
Loading…
Reference in New Issue