mirror of https://github.com/coqui-ai/TTS.git
Implement vocoder Fine Tuning like SC-GlowTTS paper
This commit is contained in:
parent
3df5d9a619
commit
9071bf326f
|
@ -598,6 +598,7 @@ class VitsGeneratorLoss(nn.Module):
|
|||
feats_disc_fake,
|
||||
feats_disc_real,
|
||||
loss_duration,
|
||||
fine_tuning_mode=False,
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -619,9 +620,15 @@ class VitsGeneratorLoss(nn.Module):
|
|||
mel = self.stft(waveform)
|
||||
mel_hat = self.stft(waveform_hat)
|
||||
# compute losses
|
||||
|
||||
# ignore tts model loss if fine tunning mode is on
|
||||
if fine_tuning_mode:
|
||||
loss_kl = 0.0
|
||||
else:
|
||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
|
||||
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
|
|
@ -193,6 +193,7 @@ class VitsArgs(Coqpit):
|
|||
use_language_embedding: bool = False
|
||||
embedded_language_dim: int = 4
|
||||
num_languages: int = 0
|
||||
fine_tuning_mode: bool = False
|
||||
|
||||
|
||||
class Vits(BaseTTS):
|
||||
|
@ -330,6 +331,7 @@ class Vits(BaseTTS):
|
|||
|
||||
if args.init_discriminator:
|
||||
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
|
||||
print("FINE TUNING:", self.args.fine_tuning_mode)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
|
@ -521,6 +523,90 @@ class Vits(BaseTTS):
|
|||
)
|
||||
return outputs
|
||||
|
||||
def forward_fine_tuning(
|
||||
self,
|
||||
x: torch.tensor,
|
||||
x_lengths: torch.tensor,
|
||||
y: torch.tensor,
|
||||
y_lengths: torch.tensor,
|
||||
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
||||
) -> Dict:
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (torch.tensor): Batch of input character sequence IDs.
|
||||
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
||||
y (torch.tensor): Batch of input spectrograms.
|
||||
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
||||
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
||||
|
||||
Returns:
|
||||
Dict: model outputs keyed by the output name.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`[B, T_seq]`
|
||||
- x_lengths: :math:`[B]`
|
||||
- y: :math:`[B, C, T_spec]`
|
||||
- y_lengths: :math:`[B]`
|
||||
- d_vectors: :math:`[B, C, 1]`
|
||||
- speaker_ids: :math:`[B]`
|
||||
"""
|
||||
with torch.no_grad():
|
||||
outputs = {}
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb=None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||
|
||||
# posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||
|
||||
# flow layers
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
# find the alignment path
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * logs_p)
|
||||
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp2 + logp3
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
# expand prior
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||
|
||||
# get the z after inverse decoder
|
||||
# ToDo: test if using m_p the result is better (In the SC-GlowTTS paper we used mp instead z_p)
|
||||
z_f_pred = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size)
|
||||
|
||||
o = self.waveform_decoder(z_slice, g=g)
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"slice_ids": slice_ids,
|
||||
"z": z,
|
||||
"z_p": z_p,
|
||||
"m_p": m_p,
|
||||
"logs_p": logs_p,
|
||||
"m_q": m_q,
|
||||
"logs_q": logs_q,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -599,6 +685,15 @@ class Vits(BaseTTS):
|
|||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
# generator pass
|
||||
if self.args.fine_tuning_mode:
|
||||
# ToDo: find better place fot it
|
||||
# force eval mode
|
||||
self.eval()
|
||||
# restore train mode for the vocoder part
|
||||
self.waveform_decoder.train()
|
||||
self.disc.train()
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
@ -610,13 +705,24 @@ class Vits(BaseTTS):
|
|||
waveform = batch["waveform"]
|
||||
|
||||
# generator pass
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
linear_input.transpose(1, 2),
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||
)
|
||||
if self.args.fine_tuning_mode:
|
||||
|
||||
# model forward
|
||||
outputs = self.forward_fine_tuning(
|
||||
text_input,
|
||||
text_lengths,
|
||||
linear_input.transpose(1, 2),
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||
)
|
||||
else:
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
linear_input.transpose(1, 2),
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||
)
|
||||
|
||||
# cache tensors for the discriminator
|
||||
self.y_disc_cache = None
|
||||
|
@ -649,15 +755,17 @@ class Vits(BaseTTS):
|
|||
feats_disc_fake=outputs["feats_disc_fake"],
|
||||
feats_disc_real=outputs["feats_disc_real"],
|
||||
loss_duration=outputs["loss_duration"],
|
||||
fine_tuning_mode=self.args.fine_tuning_mode,
|
||||
)
|
||||
|
||||
# handle the duration loss
|
||||
if self.args.use_sdp:
|
||||
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||
loss_dict["loss"] += outputs["nll_duration"]
|
||||
else:
|
||||
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||
loss_dict["loss"] += outputs["loss_duration"]
|
||||
# ignore duration loss if fine tuning mode is on
|
||||
if not self.args.fine_tuning_mode:
|
||||
# handle the duration loss
|
||||
if self.args.use_sdp:
|
||||
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||
loss_dict["loss"] += outputs["nll_duration"]
|
||||
else:
|
||||
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||
loss_dict["loss"] += outputs["loss_duration"]
|
||||
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator pass
|
||||
|
@ -853,3 +961,5 @@ class Vits(BaseTTS):
|
|||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue