mirror of https://github.com/coqui-ai/TTS.git
Remove the unusable fine-tuning model
This commit is contained in:
parent
352aa69eca
commit
6fc3b9e679
|
@ -599,7 +599,6 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
feats_disc_fake,
|
feats_disc_fake,
|
||||||
feats_disc_real,
|
feats_disc_real,
|
||||||
loss_duration,
|
loss_duration,
|
||||||
fine_tuning_mode=0,
|
|
||||||
use_speaker_encoder_as_loss=False,
|
use_speaker_encoder_as_loss=False,
|
||||||
gt_spk_emb=None,
|
gt_spk_emb=None,
|
||||||
syn_spk_emb=None,
|
syn_spk_emb=None,
|
||||||
|
@ -623,14 +622,9 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
# compute mel spectrograms from the waveforms
|
# compute mel spectrograms from the waveforms
|
||||||
mel = self.stft(waveform)
|
mel = self.stft(waveform)
|
||||||
mel_hat = self.stft(waveform_hat)
|
mel_hat = self.stft(waveform_hat)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
|
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||||
# ignore tts model loss if fine tunning mode is on
|
|
||||||
if fine_tuning_mode:
|
|
||||||
loss_kl = 0.0
|
|
||||||
else:
|
|
||||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
|
||||||
|
|
||||||
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||||
|
|
|
@ -167,11 +167,6 @@ class VitsArgs(Coqpit):
|
||||||
speaker_encoder_model_path (str):
|
speaker_encoder_model_path (str):
|
||||||
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
||||||
|
|
||||||
fine_tuning_mode (int):
|
|
||||||
Fine tuning only the vocoder part of the model, while the rest will be frozen. Defaults to 0.
|
|
||||||
Mode 0: Disabled;
|
|
||||||
Mode 1: uses the distribution predicted by the encoder and It's recommended for TTS;
|
|
||||||
Mode 2: uses the distribution predicted by the encoder and It's recommended for voice conversion.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_chars: int = 100
|
num_chars: int = 100
|
||||||
|
@ -219,7 +214,6 @@ class VitsArgs(Coqpit):
|
||||||
use_speaker_encoder_as_loss: bool = False
|
use_speaker_encoder_as_loss: bool = False
|
||||||
speaker_encoder_config_path: str = ""
|
speaker_encoder_config_path: str = ""
|
||||||
speaker_encoder_model_path: str = ""
|
speaker_encoder_model_path: str = ""
|
||||||
fine_tuning_mode: int = 0
|
|
||||||
freeze_encoder: bool = False
|
freeze_encoder: bool = False
|
||||||
freeze_DP: bool = False
|
freeze_DP: bool = False
|
||||||
freeze_PE: bool = False
|
freeze_PE: bool = False
|
||||||
|
@ -672,122 +666,6 @@ class Vits(BaseTTS):
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def forward_fine_tuning(
|
|
||||||
self,
|
|
||||||
x: torch.tensor,
|
|
||||||
x_lengths: torch.tensor,
|
|
||||||
y: torch.tensor,
|
|
||||||
y_lengths: torch.tensor,
|
|
||||||
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
|
||||||
waveform=None,
|
|
||||||
) -> Dict:
|
|
||||||
"""Forward pass of the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (torch.tensor): Batch of input character sequence IDs.
|
|
||||||
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
|
||||||
y (torch.tensor): Batch of input spectrograms.
|
|
||||||
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
|
||||||
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: model outputs keyed by the output name.
|
|
||||||
|
|
||||||
Shapes:
|
|
||||||
- x: :math:`[B, T_seq]`
|
|
||||||
- x_lengths: :math:`[B]`
|
|
||||||
- y: :math:`[B, C, T_spec]`
|
|
||||||
- y_lengths: :math:`[B]`
|
|
||||||
- d_vectors: :math:`[B, C, 1]`
|
|
||||||
- speaker_ids: :math:`[B]`
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = {}
|
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
|
||||||
# speaker embedding
|
|
||||||
if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
|
|
||||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
|
||||||
|
|
||||||
# language embedding
|
|
||||||
lang_emb = None
|
|
||||||
if self.args.use_language_embedding and lid is not None:
|
|
||||||
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
|
||||||
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
|
||||||
|
|
||||||
# posterior encoder
|
|
||||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
|
||||||
|
|
||||||
# flow layers
|
|
||||||
z_p = self.flow(z, y_mask, g=g)
|
|
||||||
|
|
||||||
# find the alignment path
|
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
|
||||||
with torch.no_grad():
|
|
||||||
o_scale = torch.exp(-2 * logs_p)
|
|
||||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
|
||||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
|
||||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
|
||||||
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
|
||||||
logp = logp2 + logp3 + logp1 + logp4
|
|
||||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
|
||||||
|
|
||||||
# expand prior
|
|
||||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
|
||||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
|
||||||
|
|
||||||
# mode 1: like SC-GlowTTS paper; mode 2: recommended for voice conversion
|
|
||||||
if self.args.fine_tuning_mode == 1:
|
|
||||||
z_ft = m_p
|
|
||||||
elif self.args.fine_tuning_mode == 2:
|
|
||||||
z_ft = z_p
|
|
||||||
else:
|
|
||||||
raise RuntimeError(" [!] Invalid Fine Tunning Mode !")
|
|
||||||
|
|
||||||
# inverse decoder and get the output
|
|
||||||
z_f_pred = self.flow(z_ft, y_mask, g=g, reverse=True)
|
|
||||||
z_slice, slice_ids = rand_segments(z_f_pred, y_lengths, self.spec_segment_size)
|
|
||||||
|
|
||||||
o = self.waveform_decoder(z_slice, g=g)
|
|
||||||
|
|
||||||
wav_seg = segment(
|
|
||||||
waveform.transpose(1, 2),
|
|
||||||
slice_ids * self.config.audio.hop_length,
|
|
||||||
self.args.spec_segment_size * self.config.audio.hop_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
|
|
||||||
# concate generated and GT waveforms
|
|
||||||
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
|
||||||
|
|
||||||
# resample audio to speaker encoder sample_rate
|
|
||||||
if self.audio_transform is not None:
|
|
||||||
wavs_batch = self.audio_transform(wavs_batch)
|
|
||||||
|
|
||||||
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
|
||||||
|
|
||||||
# split generated and GT speaker embeddings
|
|
||||||
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
|
|
||||||
else:
|
|
||||||
gt_spk_emb, syn_spk_emb = None, None
|
|
||||||
|
|
||||||
outputs.update(
|
|
||||||
{
|
|
||||||
"model_outputs": o,
|
|
||||||
"alignments": attn.squeeze(1),
|
|
||||||
"loss_duration": 0.0,
|
|
||||||
"z": z,
|
|
||||||
"z_p": z_p,
|
|
||||||
"m_p": m_p,
|
|
||||||
"logs_p": logs_p,
|
|
||||||
"m_q": m_q,
|
|
||||||
"logs_q": logs_q,
|
|
||||||
"waveform_seg": wav_seg,
|
|
||||||
"gt_spk_emb": gt_spk_emb,
|
|
||||||
"syn_spk_emb": syn_spk_emb,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
||||||
"""
|
"""
|
||||||
|
@ -869,15 +747,6 @@ class Vits(BaseTTS):
|
||||||
if optimizer_idx not in [0, 1]:
|
if optimizer_idx not in [0, 1]:
|
||||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||||
|
|
||||||
# generator pass
|
|
||||||
if self.args.fine_tuning_mode:
|
|
||||||
# ToDo: find better place fot it
|
|
||||||
# force eval mode
|
|
||||||
self.eval()
|
|
||||||
# restore train mode for the vocoder part
|
|
||||||
self.waveform_decoder.train()
|
|
||||||
self.disc.train()
|
|
||||||
|
|
||||||
if self.args.freeze_encoder:
|
if self.args.freeze_encoder:
|
||||||
for param in self.text_encoder.parameters():
|
for param in self.text_encoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
@ -913,25 +782,14 @@ class Vits(BaseTTS):
|
||||||
waveform = batch["waveform"]
|
waveform = batch["waveform"]
|
||||||
|
|
||||||
# generator pass
|
# generator pass
|
||||||
if self.args.fine_tuning_mode:
|
outputs = self.forward(
|
||||||
# model forward
|
text_input,
|
||||||
outputs = self.forward_fine_tuning(
|
text_lengths,
|
||||||
text_input,
|
linear_input.transpose(1, 2),
|
||||||
text_lengths,
|
mel_lengths,
|
||||||
linear_input.transpose(1, 2),
|
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||||
mel_lengths,
|
waveform=waveform,
|
||||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
)
|
||||||
waveform=waveform,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
outputs = self.forward(
|
|
||||||
text_input,
|
|
||||||
text_lengths,
|
|
||||||
linear_input.transpose(1, 2),
|
|
||||||
mel_lengths,
|
|
||||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
|
||||||
waveform=waveform,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cache tensors for the discriminator
|
# cache tensors for the discriminator
|
||||||
self.y_disc_cache = None
|
self.y_disc_cache = None
|
||||||
|
@ -958,7 +816,6 @@ class Vits(BaseTTS):
|
||||||
feats_disc_fake=outputs["feats_disc_fake"],
|
feats_disc_fake=outputs["feats_disc_fake"],
|
||||||
feats_disc_real=outputs["feats_disc_real"],
|
feats_disc_real=outputs["feats_disc_real"],
|
||||||
loss_duration=outputs["loss_duration"],
|
loss_duration=outputs["loss_duration"],
|
||||||
fine_tuning_mode=self.args.fine_tuning_mode,
|
|
||||||
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
|
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
|
||||||
gt_spk_emb=outputs["gt_spk_emb"],
|
gt_spk_emb=outputs["gt_spk_emb"],
|
||||||
syn_spk_emb=outputs["syn_spk_emb"],
|
syn_spk_emb=outputs["syn_spk_emb"],
|
||||||
|
|
Loading…
Reference in New Issue