Update VITS model

Eren Gölge 2021-12-30 12:02:35 +00:00
parent 638091f41d
commit 7129b04d46
1 changed file with 38 additions and 24 deletions


@@ -385,24 +385,29 @@ class Vits(BaseTTS):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
config = config.model_args
self.num_speakers = self.args.num_speakers
self.num_speakers = config.num_speakers
if self.speaker_manager:
self.num_speakers = self.speaker_manager.num_speakers
if config.use_speaker_embedding:
if self.args.use_speaker_embedding:
self._init_speaker_embedding(config)
if config.use_d_vector_file:
if self.args.use_d_vector_file:
self._init_d_vector(config)
# TODO: make this a function
if config.use_speaker_encoder_as_loss:
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
not config.speaker_encoder_model_path or not config.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)
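Note: the docstring above describes two speaker-conditioning modes, a learned embedding table and externally computed d-vectors. Below is a minimal standalone sketch of that split in plain PyTorch; the SpeakerConditioning class and its default sizes are illustrative only, while the flag and attribute names mirror the ones in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpeakerConditioning(nn.Module):
    """Sketch of the two multi-speaker modes: a trainable speaker
    embedding table vs. d-vectors from an external speaker encoder."""

    def __init__(self, num_speakers=0, use_speaker_embedding=False,
                 use_d_vector_file=False, speaker_embedding_channels=256, d_vector_dim=512):
        super().__init__()
        self.use_speaker_embedding = use_speaker_embedding
        self.use_d_vector_file = use_d_vector_file
        self.embedded_speaker_dim = 0
        if use_speaker_embedding and num_speakers > 0:
            # learned lookup table, trained jointly with the model
            self.embedded_speaker_dim = speaker_embedding_channels
            self.emb_g = nn.Embedding(num_speakers, speaker_embedding_channels)
        elif use_d_vector_file:
            # embeddings come precomputed from a speaker encoder; no table is created
            self.embedded_speaker_dim = d_vector_dim

    def condition(self, speaker_ids=None, d_vectors=None):
        if self.use_speaker_embedding and speaker_ids is not None:
            return self.emb_g(speaker_ids).unsqueeze(-1)  # [B, C, 1]
        if self.use_d_vector_file and d_vectors is not None:
            return F.normalize(d_vectors).unsqueeze(-1)   # [B, C, 1]
        return None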
@@ -412,7 +417,8 @@ class Vits(BaseTTS):
if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
and self.config.audio["sample_rate"]
!= self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
):
# TODO: change this with torchaudio Resample
raise RuntimeError(
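The TODO above points at torchaudio's resampler. A small sketch of how the sample-rate mismatch could be bridged instead of raising, assuming torchaudio is installed; the two rates below are example values, not read from any config.

import torch
import torchaudio

model_sr = 22050    # e.g. config.audio["sample_rate"]
encoder_sr = 16000  # e.g. speaker_encoder.audio_config["sample_rate"]

# Build the transform once and reuse it; Resample operates on the last (time) axis.
resample = torchaudio.transforms.Resample(orig_freq=model_sr, new_freq=encoder_sr)

wav = torch.randn(1, model_sr)     # one second of dummy audio at the model rate
wav_for_encoder = resample(wav)    # same audio at the speaker-encoder rate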
@@ -434,14 +440,14 @@ class Vits(BaseTTS):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = config.speaker_embedding_channels
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
def _init_d_vector(self, config):
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.embedded_speaker_dim = config.d_vector_dim
self.embedded_speaker_dim = self.args.d_vector_dim
def init_multilingual(self, config: Coqpit):
"""Initialize multilingual modules of a model.
@@ -449,15 +455,12 @@ class Vits(BaseTTS):
Args:
config (Coqpit): Model configuration.
"""
if hasattr(config, "model_args"):
config = config.model_args
if config.language_ids_file is not None:
if self.args.language_ids_file is not None:
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if config.use_language_embedding and self.language_manager:
if self.args.use_language_embedding and self.language_manager:
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = config.embedded_language_dim
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
torch.nn.init.xavier_uniform_(self.emb_l.weight)
else:
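The language-embedding branch above is a Xavier-initialized lookup table. A standalone sketch with placeholder sizes (the counts below are arbitrary examples):

import torch
import torch.nn as nn

num_languages = 3
embedded_language_dim = 4

emb_l = nn.Embedding(num_languages, embedded_language_dim)
torch.nn.init.xavier_uniform_(emb_l.weight)

language_ids = torch.tensor([0, 2, 1])        # [B]
lang_emb = emb_l(language_ids).unsqueeze(-1)  # [B, embedded_language_dim, 1]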
@@ -486,7 +489,7 @@ class Vits(BaseTTS):
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
def get_aux_input_from_test_sentences(self, sentence_info):
if hasattr(self.config, "model_args"):
@@ -542,8 +545,8 @@ class Vits(BaseTTS):
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
waveform: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
waveform=None,
) -> Dict:
"""Forward pass of the model.
@@ -552,6 +555,7 @@ class Vits(BaseTTS):
x_lengths (torch.tensor): Batch of input character sequence lengths.
y (torch.tensor): Batch of input spectrograms.
y_lengths (torch.tensor): Batch of input spectrogram lengths.
waveform (torch.tensor): Batch of ground truth waveforms per sample.
aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.
@@ -563,6 +567,7 @@ class Vits(BaseTTS):
- x_lengths: :math:`[B]`
- y: :math:`[B, C, T_spec]`
- y_lengths: :math:`[B]`
- waveform: :math:`[B, T_wav, 1]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
- language_ids: :math:`[B]`
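Putting the documented shapes together, a forward call could be assembled as sketched below with dummy tensors. `model` stands for an already initialized Vits instance and is not constructed here; the keyword-style `waveform=waveform` mirrors the call shown in this diff, and all sizes are arbitrary examples.

import torch

B, T_seq, C_spec, T_spec, hop_length = 2, 50, 513, 200, 256

x = torch.randint(0, 100, (B, T_seq))                    # [B, T_seq] token ids
x_lengths = torch.full((B,), T_seq, dtype=torch.long)    # [B]
y = torch.randn(B, C_spec, T_spec)                       # [B, C, T_spec]
y_lengths = torch.full((B,), T_spec, dtype=torch.long)   # [B]
waveform = torch.randn(B, T_spec * hop_length, 1)        # [B, T_wav, 1]

# outputs = model.forward(
#     x, x_lengths, y, y_lengths,
#     aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
#     waveform=waveform,
# )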
@@ -628,14 +633,14 @@ class Vits(BaseTTS):
o = self.waveform_decoder(z_slice, g=g)
wav_seg = segment(
waveform.transpose(1, 2),
waveform,
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
# concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
wavs_batch = torch.cat((wav_seg, o), dim=0)
# resample audio to speaker encoder sample_rate
# pylint: disable=W0105
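The segment call above cuts the ground-truth waveform at the same positions as the sliced latents, converting frame indices to sample indices via the hop length. A self-contained sketch of that slicing idea; slice_segments below is a hypothetical stand-in, not the library's segment helper.

import torch

def slice_segments(x, start_indices, segment_size):
    """Cut fixed-size windows from x ([B, C, T]) at per-item start positions."""
    return torch.stack(
        [x[i, :, s : s + segment_size] for i, s in enumerate(start_indices.tolist())]
    )

hop_length = 256
spec_segment_size = 32                                            # length in spectrogram frames
waveform = torch.randn(2, 1, 4 * spec_segment_size * hop_length)  # [B, 1, T_wav]
slice_ids = torch.tensor([0, 3])                                  # frame offsets of the latent slice

wav_seg = slice_segments(
    waveform,
    slice_ids * hop_length,              # frame index -> sample index
    spec_segment_size * hop_length,      # segment length in samples
)
print(wav_seg.shape)                     # torch.Size([2, 1, 8192])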
@@ -712,20 +717,29 @@ class Vits(BaseTTS):
return outputs
def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
"""TODO: create an end-point for voice conversion"""
"""Forward pass for voice conversion
TODO: create an end-point for voice conversion
Args:
y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
"""
assert self.num_speakers > 0, "num_speakers have to be larger than 0."
# speaker embedding
if self.args.use_speaker_embedding and not self.use_d_vector:
if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
g_src = self.emb_g(speaker_cond_src).unsqueeze(-1)
g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1)
elif self.args.use_speaker_embedding and self.use_d_vector:
elif self.args.use_speaker_embedding and self.args.use_d_vector_file:
g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
else:
raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
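Per the docstring above, voice conversion takes a reference spectrogram plus source and target speaker IDs and re-decodes the latents under the target speaker. A hedged usage sketch with dummy shapes only; `model` is assumed to be a trained multi-speaker Vits instance and is not built here, and the channel count is an arbitrary example.

import torch

B, T, C = 1, 200, 513                  # y documented as [B, T, C] reference spectrograms
y = torch.randn(B, T, C)
y_lengths = torch.tensor([T])
speaker_cond_src = torch.tensor([0])   # reference speaker id, shape [B]
speaker_cond_tgt = torch.tensor([3])   # target speaker id, shape [B]

# converted = model.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)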
@@ -786,8 +800,8 @@ class Vits(BaseTTS):
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
waveform.transpose(1, 2),
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
waveform=waveform,
)
# cache tensors for the discriminator