mirror of https://github.com/coqui-ai/TTS.git
Update VITS model
commit 7129b04d46
parent 638091f41d
@@ -385,24 +385,29 @@ class Vits(BaseTTS):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
        """
        self.embedded_speaker_dim = 0
        config = config.model_args
        self.num_speakers = self.args.num_speakers

        self.num_speakers = config.num_speakers
        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if config.use_speaker_embedding:
        if self.args.use_speaker_embedding:
            self._init_speaker_embedding(config)

        if config.use_d_vector_file:
        if self.args.use_d_vector_file:
            self._init_d_vector(config)

        # TODO: make this a function
        if config.use_speaker_encoder_as_loss:
            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
        if self.args.use_speaker_encoder_as_loss:
            if self.speaker_manager.speaker_encoder is None and (
                not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
            ):
                raise RuntimeError(
                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                )
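As context for the hunk above: the speaker setup now reads its flags from `self.args` rather than re-deriving them from `config.model_args`. A minimal, self-contained sketch of that branching, with invented names (`VitsArgsSketch`, `init_speaker_modules`) that are not part of the repository:

# Hedged sketch of the speaker-setup branching; names below are illustrative only.
from dataclasses import dataclass

from torch import nn


@dataclass
class VitsArgsSketch:  # stand-in for the real VitsArgs coqpit
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    use_d_vector_file: bool = False
    speaker_embedding_channels: int = 256
    d_vector_dim: int = 256


def init_speaker_modules(args: VitsArgsSketch):
    """Return a learnable speaker table, the external d-vector size, or None."""
    if args.use_speaker_embedding and args.num_speakers > 0:
        # one learned embedding vector per speaker id
        return nn.Embedding(args.num_speakers, args.speaker_embedding_channels)
    if args.use_d_vector_file:
        # d-vectors are precomputed externally; only their size matters here
        return args.d_vector_dim
    return None


print(init_speaker_modules(VitsArgsSketch(num_speakers=4, use_speaker_embedding=True)))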
@@ -412,7 +417,8 @@ class Vits(BaseTTS):

            if (
                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
                and self.config.audio["sample_rate"]
                != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
            ):
                # TODO: change this with torchaudio Resample
                raise RuntimeError(
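The `# TODO: change this with torchaudio Resample` comment points at resampling the generated audio to the speaker encoder's rate instead of raising. A hedged sketch of what such a resample could look like with torchaudio; this is not what the code does at this commit, and the sample rates below are made up:

import torch
import torchaudio

model_sample_rate = 22050    # stand-in for self.config.audio["sample_rate"]
encoder_sample_rate = 16000  # stand-in for speaker_encoder.audio_config["sample_rate"]

if model_sample_rate != encoder_sample_rate:
    resample = torchaudio.transforms.Resample(orig_freq=model_sample_rate, new_freq=encoder_sample_rate)
    wav = torch.randn(1, model_sample_rate)  # one second of dummy audio
    print(resample(wav).shape)               # roughly [1, encoder_sample_rate]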
@@ -434,14 +440,14 @@ class Vits(BaseTTS):
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")
            self.embedded_speaker_dim = config.speaker_embedding_channels
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self, config):
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
        self.embedded_speaker_dim = config.d_vector_dim
        self.embedded_speaker_dim = self.args.d_vector_dim

    def init_multilingual(self, config: Coqpit):
        """Initialize multilingual modules of a model.
@@ -449,15 +455,12 @@ class Vits(BaseTTS):
        Args:
            config (Coqpit): Model configuration.
        """
        if hasattr(config, "model_args"):
            config = config.model_args

        if config.language_ids_file is not None:
        if self.args.language_ids_file is not None:
            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

        if config.use_language_embedding and self.language_manager:
        if self.args.use_language_embedding and self.language_manager:
            self.num_languages = self.language_manager.num_languages
            self.embedded_language_dim = config.embedded_language_dim
            self.embedded_language_dim = self.args.embedded_language_dim
            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
            torch.nn.init.xavier_uniform_(self.emb_l.weight)
        else:
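For intuition about what `init_multilingual` sets up when `use_language_embedding` is enabled, a tiny standalone sketch of the language lookup table (sizes here are arbitrary):

import torch
from torch import nn

num_languages, embedded_language_dim = 3, 4
emb_l = nn.Embedding(num_languages, embedded_language_dim)
torch.nn.init.xavier_uniform_(emb_l.weight)

language_ids = torch.tensor([0, 2, 1])  # [B]
print(emb_l(language_ids).shape)        # [B, embedded_language_dim]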
@@ -486,7 +489,7 @@ class Vits(BaseTTS):

    def get_aux_input(self, aux_input: Dict):
        sid, g, lid = self._set_cond_input(aux_input)
        return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}
        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
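The key rename in `get_aux_input` matters because callers index the plural names. A small illustration of the new dictionary shape; the values below are dummies, not real model outputs:

import torch

aux_out = {"speaker_ids": torch.tensor([0]), "style_wav": None, "d_vectors": None, "language_ids": None}
print(aux_out["speaker_ids"])  # older code exposed the singular "speaker_id" key instead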
@@ -542,8 +545,8 @@ class Vits(BaseTTS):
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
        waveform=None,
    ) -> Dict:
        """Forward pass of the model.
@@ -552,6 +555,7 @@ class Vits(BaseTTS):
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            waveform (torch.tensor): Batch of ground truth waveforms per sample.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
@@ -563,6 +567,7 @@ class Vits(BaseTTS):
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - waveform: :math:`[B, T_wav, 1]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
            - language_ids: :math:`[B]`
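The shape notes above can be exercised with dummy tensors; the batch size, lengths, and channel counts below are arbitrary:

import torch

B, T_seq, C_spec, T_spec, hop_length = 2, 50, 513, 200, 256

x = torch.randint(0, 100, (B, T_seq))              # token ids, [B, T_seq]
x_lengths = torch.tensor([T_seq, T_seq - 10])      # [B]
y = torch.randn(B, C_spec, T_spec)                 # spectrograms, [B, C, T_spec]
y_lengths = torch.tensor([T_spec, T_spec - 20])    # [B]
waveform = torch.randn(B, T_spec * hop_length, 1)  # [B, T_wav, 1]
print(x.shape, y.shape, waveform.shape)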
@@ -628,14 +633,14 @@ class Vits(BaseTTS):
        o = self.waveform_decoder(z_slice, g=g)

        wav_seg = segment(
            waveform.transpose(1, 2),
            waveform,
            slice_ids * self.config.audio.hop_length,
            self.args.spec_segment_size * self.config.audio.hop_length,
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
            # concate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
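`segment` crops a fixed-length window per batch item from the waveform, aligned with the randomly sliced latent frames through `hop_length`. A rough standalone re-implementation of the idea; it is not the repository's helper and skips its edge-case handling:

import torch


def crop_segments(x: torch.Tensor, starts: torch.Tensor, segment_size: int) -> torch.Tensor:
    """Crop one window of `segment_size` samples per item of x ([B, C, T]), starting at `starts` ([B])."""
    out = torch.zeros(x.size(0), x.size(1), segment_size, dtype=x.dtype)
    for i, start in enumerate(starts.tolist()):
        out[i] = x[i, :, start : start + segment_size]
    return out


hop_length, spec_segment_size = 256, 32
waveform = torch.randn(2, 1, 100 * hop_length)  # already [B, 1, T_wav]
slice_ids = torch.tensor([3, 10])               # frame-level slice starts
wav_seg = crop_segments(waveform, slice_ids * hop_length, spec_segment_size * hop_length)
print(wav_seg.shape)                            # [2, 1, 8192]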
@@ -712,20 +717,29 @@ class Vits(BaseTTS):
        return outputs

    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
        """TODO: create an end-point for voice conversion"""
        """Forward pass for voice conversion

        TODO: create an end-point for voice conversion

        Args:
            y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
            y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
            speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
            speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
        """
        assert self.num_speakers > 0, "num_speakers have to be larger than 0."

        # speaker embedding
        if self.args.use_speaker_embedding and not self.use_d_vector:
        if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
            g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
            g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
        elif self.args.use_speaker_embedding and self.use_d_vector:
        elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
            g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
            g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
        else:
            raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")

        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
        z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
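For the d-vector branch above, `F.normalize(...).unsqueeze(-1)` turns a [B, C] speaker vector into the [B, C, 1] conditioning shape the flow and decoder expect. A minimal shape check with made-up sizes:

import torch
import torch.nn.functional as F

d_vector_dim = 256
speaker_cond_src = torch.randn(1, d_vector_dim)      # [B, C] reference-speaker d-vector
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)  # [B, C, 1]
print(g_src.shape)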
@@ -786,8 +800,8 @@ class Vits(BaseTTS):
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                waveform.transpose(1, 2),
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
                waveform=waveform,
            )

            # cache tensors for the discriminator
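My reading of the transposes above: the batch tensors arrive time-major ([B, T, C]) from the loader, while `forward` expects channel-major ([B, C, T]) spectrograms and waveforms, hence `linear_input.transpose(1, 2)` and `waveform.transpose(1, 2)`. A quick shape illustration with arbitrary sizes:

import torch

linear_input = torch.randn(2, 200, 513)    # [B, T_spec, C_spec] as it comes out of the loader
waveform = torch.randn(2, 51200, 1)        # [B, T_wav, 1]
print(linear_input.transpose(1, 2).shape)  # [B, C_spec, T_spec]
print(waveform.transpose(1, 2).shape)      # [B, 1, T_wav]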