mirror of https://github.com/coqui-ai/TTS.git
Minor bug fixes on VITS/YourTTS and inference (#2054)
* Set the right device for the speaker encoder
* Fix the inference list_language_idxs parameter
* Fix the speaker encoder resample audio transform
This commit is contained in:
parent
5f5d441ee5
commit
f3b947e706
|
@ -331,7 +331,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
||||||
print(
|
print(
|
||||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||||
)
|
)
|
||||||
print(synthesizer.tts_model.language_manager.ids)
|
print(synthesizer.tts_model.language_manager.name_to_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
# check the arguments against a multi-speaker model.
|
# check the arguments against a multi-speaker model.
|
||||||
|
|
|
@ -721,6 +721,10 @@ class Vits(BaseTTS):
|
||||||
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
def init_multispeaker(self, config: Coqpit):
|
def init_multispeaker(self, config: Coqpit):
|
||||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||||
or with external `d_vectors` computed from a speaker encoder model.
|
or with external `d_vectors` computed from a speaker encoder model.
|
||||||
|
@ -758,17 +762,12 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.speaker_manager.encoder, "audio_config")
|
hasattr(self.speaker_manager.encoder, "audio_config")
|
||||||
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
|
and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"]
|
||||||
):
|
):
|
||||||
self.audio_transform = torchaudio.transforms.Resample(
|
self.audio_transform = torchaudio.transforms.Resample(
|
||||||
orig_freq=self.audio_config["sample_rate"],
|
orig_freq=self.config.audio.sample_rate,
|
||||||
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
|
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
|
||||||
)
|
)
|
||||||
# pylint: disable=W0101,W0105
|
|
||||||
self.audio_transform = torchaudio.transforms.Resample(
|
|
||||||
orig_freq=self.config.audio.sample_rate,
|
|
||||||
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_speaker_embedding(self):
|
def _init_speaker_embedding(self):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
|
@ -811,6 +810,13 @@ class Vits(BaseTTS):
|
||||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||||
) # pylint: disable=W0201
|
) # pylint: disable=W0201
|
||||||
|
|
||||||
|
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
||||||
|
"""Freeze layers at the beginning of an epoch"""
|
||||||
|
self._freeze_layers()
|
||||||
|
# set the device of speaker encoder
|
||||||
|
if self.args.use_speaker_encoder_as_loss:
|
||||||
|
self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device)
|
||||||
|
|
||||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||||
"""Reinit layes if needed"""
|
"""Reinit layes if needed"""
|
||||||
if self.args.reinit_DP:
|
if self.args.reinit_DP:
|
||||||
|
@ -1231,8 +1237,6 @@ class Vits(BaseTTS):
|
||||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._freeze_layers()
|
|
||||||
|
|
||||||
spec_lens = batch["spec_lens"]
|
spec_lens = batch["spec_lens"]
|
||||||
|
|
||||||
if optimizer_idx == 0:
|
if optimizer_idx == 0:
|
||||||
|
|
Loading…
Reference in New Issue