mirror of https://github.com/coqui-ai/TTS.git
Update VITS model

commit 7129b04d46
parent 638091f41d
@@ -385,24 +385,29 @@ class Vits(BaseTTS):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
 
+        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
+
         Args:
             config (Coqpit): Model configuration.
             data (List, optional): Dataset items to infer number of speakers. Defaults to None.
         """
         self.embedded_speaker_dim = 0
-        config = config.model_args
+        self.num_speakers = self.args.num_speakers
 
-        self.num_speakers = config.num_speakers
+        if self.speaker_manager:
+            self.num_speakers = self.speaker_manager.num_speakers
 
-        if config.use_speaker_embedding:
+        if self.args.use_speaker_embedding:
             self._init_speaker_embedding(config)
 
-        if config.use_d_vector_file:
+        if self.args.use_d_vector_file:
             self._init_d_vector(config)
 
         # TODO: make this a function
-        if config.use_speaker_encoder_as_loss:
-            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
+        if self.args.use_speaker_encoder_as_loss:
+            if self.speaker_manager.speaker_encoder is None and (
+                not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
+            ):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
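The hunk above moves the multi-speaker options from the (possibly nested) `config` object to `self.args` and lets an attached `speaker_manager` override the speaker count. Below is a minimal sketch of that control flow with plain PyTorch; `Args`, `SpeakerManagerStub`, and `MultiSpeakerStub` are illustrative stand-ins, not library code.

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class Args:
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    speaker_embedding_channels: int = 256


class SpeakerManagerStub:
    """Stand-in for the speaker manager built from the dataset."""

    def __init__(self, num_speakers):
        self.num_speakers = num_speakers


class MultiSpeakerStub:
    def __init__(self, args, speaker_manager=None):
        self.args = args
        self.speaker_manager = speaker_manager
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers
        if self.speaker_manager:
            # the manager, built from the data, wins over the config value
            self.num_speakers = self.speaker_manager.num_speakers
        if self.args.use_speaker_embedding and self.num_speakers > 0:
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)


model = MultiSpeakerStub(Args(use_speaker_embedding=True), SpeakerManagerStub(4))
print(model.emb_g(torch.tensor([0, 3])).shape)  # torch.Size([2, 256])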
@@ -412,7 +417,8 @@ class Vits(BaseTTS):
 
             if (
                 hasattr(self.speaker_manager.speaker_encoder, "audio_config")
-                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
+                and self.config.audio["sample_rate"]
+                != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
             ):
                 # TODO: change this with torchaudio Resample
                 raise RuntimeError(
@@ -434,14 +440,14 @@ class Vits(BaseTTS):
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
-            self.embedded_speaker_dim = config.speaker_embedding_channels
+            self.embedded_speaker_dim = self.args.speaker_embedding_channels
             self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
 
     def _init_d_vector(self, config):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
-        self.embedded_speaker_dim = config.d_vector_dim
+        self.embedded_speaker_dim = self.args.d_vector_dim
 
     def init_multilingual(self, config: Coqpit):
         """Initialize multilingual modules of a model.
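For context on the two conditioning paths touched above: a learned speaker embedding is `speaker_embedding_channels` wide, while an external d-vector is `d_vector_dim` wide. The hypothetical snippet below only illustrates the resulting tensor shapes with plain PyTorch; the sizes are examples, not repo defaults.

import torch
import torch.nn.functional as F
from torch import nn

num_speakers, speaker_embedding_channels, d_vector_dim = 4, 256, 512

# path 1: learned lookup table, one row per speaker id
emb_g = nn.Embedding(num_speakers, speaker_embedding_channels)
g_from_id = emb_g(torch.tensor([2])).unsqueeze(-1)    # [1, 256, 1]

# path 2: externally computed d-vector, length-normalized
d_vector = torch.randn(1, d_vector_dim)
g_from_dvec = F.normalize(d_vector).unsqueeze(-1)     # [1, 512, 1]

print(g_from_id.shape, g_from_dvec.shape)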
@@ -449,15 +455,12 @@ class Vits(BaseTTS):
         Args:
             config (Coqpit): Model configuration.
         """
-        if hasattr(config, "model_args"):
-            config = config.model_args
-
-        if config.language_ids_file is not None:
+        if self.args.language_ids_file is not None:
             self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
 
-        if config.use_language_embedding and self.language_manager:
+        if self.args.use_language_embedding and self.language_manager:
             self.num_languages = self.language_manager.num_languages
-            self.embedded_language_dim = config.embedded_language_dim
+            self.embedded_language_dim = self.args.embedded_language_dim
             self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
             torch.nn.init.xavier_uniform_(self.emb_l.weight)
         else:
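A small sketch of what `init_multilingual` builds when `use_language_embedding` is enabled, mirroring the lines above; the sizes here are illustrative, not values from the repo.

import torch
from torch import nn

num_languages = 3            # e.g. taken from a LanguageManager built on language_ids_file
embedded_language_dim = 4    # self.args.embedded_language_dim in the diff

emb_l = nn.Embedding(num_languages, embedded_language_dim)
torch.nn.init.xavier_uniform_(emb_l.weight)

language_ids = torch.tensor([0, 2])            # [B]
lang_emb = emb_l(language_ids).unsqueeze(-1)   # [B, embedded_language_dim, 1]
print(lang_emb.shape)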
@@ -486,7 +489,7 @@ class Vits(BaseTTS):
 
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
-        return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}
+        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     def get_aux_input_from_test_sentences(self, sentence_info):
         if hasattr(self.config, "model_args"):
@@ -542,8 +545,8 @@ class Vits(BaseTTS):
         x_lengths: torch.tensor,
         y: torch.tensor,
         y_lengths: torch.tensor,
+        waveform: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
-        waveform=None,
     ) -> Dict:
         """Forward pass of the model.
 
@@ -552,6 +555,7 @@ class Vits(BaseTTS):
             x_lengths (torch.tensor): Batch of input character sequence lengths.
             y (torch.tensor): Batch of input spectrograms.
             y_lengths (torch.tensor): Batch of input spectrogram lengths.
+            waveform (torch.tensor): Batch of ground truth waveforms per sample.
             aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                 Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
 
@@ -563,6 +567,7 @@ class Vits(BaseTTS):
             - x_lengths: :math:`[B]`
             - y: :math:`[B, C, T_spec]`
             - y_lengths: :math:`[B]`
+            - waveform: :math:`[B, T_wav, 1]`
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
             - language_ids: :math:`[B]`
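Purely as illustration of the new argument order (the ground-truth `waveform` is now a required positional argument placed before `aux_input`), here are hypothetical dummy inputs shaped as the docstring above describes; `model` is assumed to be an initialized `Vits` instance, so the call itself is left commented out.

import torch

B, T_text, C_spec, T_spec, T_wav = 2, 21, 513, 60, 15360
x = torch.randint(0, 100, (B, T_text))     # token ids, [B, T_text]
x_lengths = torch.tensor([T_text, 15])     # [B]
y = torch.randn(B, C_spec, T_spec)         # [B, C, T_spec]
y_lengths = torch.tensor([T_spec, 52])     # [B]
waveform = torch.randn(B, T_wav, 1)        # [B, T_wav, 1] as documented above
aux_input = {"d_vectors": None, "speaker_ids": None, "language_ids": None}

# outputs = model.forward(x, x_lengths, y, y_lengths, waveform, aux_input=aux_input)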
@@ -628,14 +633,14 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
-            waveform.transpose(1, 2),
+            waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
             # concate generated and GT waveforms
-            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+            wavs_batch = torch.cat((wav_seg, o), dim=0)
 
             # resample audio to speaker encoder sample_rate
             # pylint: disable=W0105
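The change above hands `segment()` a waveform that the caller has already transposed (see the train_step hunk further down), so the helper slices a channels-first tensor directly. Below is a minimal re-implementation sketch of that slicing, not the library helper itself; the hop length and segment size are illustrative values.

import torch


def segment_sketch(x: torch.Tensor, segment_indices: torch.Tensor, segment_size: int) -> torch.Tensor:
    """Cut one fixed-size segment per batch item, starting at segment_indices."""
    segments = torch.zeros_like(x[:, :, :segment_size])
    for i, start in enumerate(segment_indices.tolist()):
        segments[i] = x[i, :, start : start + segment_size]
    return segments


hop_length, spec_segment_size = 256, 32          # illustrative values
waveform = torch.randn(4, 1, 200 * hop_length)   # [B, 1, T_wav]
slice_ids = torch.tensor([0, 10, 50, 100])       # frame-level segment starts
wav_seg = segment_sketch(waveform, slice_ids * hop_length, spec_segment_size * hop_length)
print(wav_seg.shape)  # torch.Size([4, 1, 8192])

Frame-level `slice_ids` are scaled by `hop_length` so the audio segment lines up with the sliced latent fed to the waveform decoder.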
@@ -712,20 +717,29 @@ class Vits(BaseTTS):
         return outputs
 
     def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
-        """TODO: create an end-point for voice conversion"""
+        """Forward pass for voice conversion
+
+        TODO: create an end-point for voice conversion
+
+        Args:
+            y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
+            y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
+            speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
+            speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
+        """
         assert self.num_speakers > 0, "num_speakers have to be larger than 0."
 
         # speaker embedding
-        if self.args.use_speaker_embedding and not self.use_d_vector:
+        if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
             g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
             g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
-        elif self.args.use_speaker_embedding and self.use_d_vector:
+        elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
             g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
             g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
 
-        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
+        z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
         z_p = self.flow(z, y_mask, g=g_src)
         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
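The conversion path above encodes the reference with the source speaker's conditioning, maps it into the prior space through the flow, and inverts the flow with the target speaker's conditioning before decoding. A toy, self-contained sketch of that flow-swap idea follows; the affine transform is a stand-in for the model's real residual-coupling flow and posterior encoder, not the library code.

import torch
from torch import nn

dim, num_speakers = 8, 4
emb_g = nn.Embedding(num_speakers, dim)   # speaker lookup, as in the diff
scale = nn.Linear(dim, dim)               # stand-in conditional flow parameters
shift = nn.Linear(dim, dim)


def toy_flow(z, g, reverse=False):
    # an invertible, speaker-conditioned affine map: y = z * exp(s) + t
    s, t = torch.tanh(scale(g)), shift(g)
    return (z - t) * torch.exp(-s) if reverse else z * torch.exp(s) + t


z = torch.randn(1, dim)                                    # stands in for the posterior sample
g_src, g_tgt = emb_g(torch.tensor([0])), emb_g(torch.tensor([3]))
z_p = toy_flow(z, g_src)                                   # into the prior space, source speaker
z_hat = toy_flow(z_p, g_tgt, reverse=True)                 # back out, target speaker
print(z.shape, z_hat.shape)

In the actual method, the posterior encoder produces `z` from the `[B, T, C]` reference (transposed to channels-first inside the call), and `waveform_decoder` renders `z_hat` back to audio.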
@@ -786,8 +800,8 @@ class Vits(BaseTTS):
                 text_lengths,
                 linear_input.transpose(1, 2),
                 mel_lengths,
+                waveform.transpose(1, 2),
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
-                waveform=waveform,
             )
 
             # cache tensors for the discriminator