mirror of https://github.com/coqui-ai/TTS.git
Implement vocoder Fine Tuning like SC-GlowTTS paper
This commit is contained in:
parent
3df5d9a619
commit
9071bf326f
|
@ -598,6 +598,7 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
feats_disc_fake,
|
feats_disc_fake,
|
||||||
feats_disc_real,
|
feats_disc_real,
|
||||||
loss_duration,
|
loss_duration,
|
||||||
|
fine_tuning_mode=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -619,9 +620,15 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
mel = self.stft(waveform)
|
mel = self.stft(waveform)
|
||||||
mel_hat = self.stft(waveform_hat)
|
mel_hat = self.stft(waveform_hat)
|
||||||
# compute losses
|
# compute losses
|
||||||
|
|
||||||
|
# ignore tts model loss if fine tunning mode is on
|
||||||
|
if fine_tuning_mode:
|
||||||
|
loss_kl = 0.0
|
||||||
|
else:
|
||||||
|
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||||
|
|
||||||
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
|
||||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||||
|
|
|
@ -193,6 +193,7 @@ class VitsArgs(Coqpit):
|
||||||
use_language_embedding: bool = False
|
use_language_embedding: bool = False
|
||||||
embedded_language_dim: int = 4
|
embedded_language_dim: int = 4
|
||||||
num_languages: int = 0
|
num_languages: int = 0
|
||||||
|
fine_tuning_mode: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Vits(BaseTTS):
|
class Vits(BaseTTS):
|
||||||
|
@ -330,6 +331,7 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
if args.init_discriminator:
|
if args.init_discriminator:
|
||||||
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
|
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
|
||||||
|
print("FINE TUNING:", self.args.fine_tuning_mode)
|
||||||
|
|
||||||
def init_multispeaker(self, config: Coqpit):
|
def init_multispeaker(self, config: Coqpit):
|
||||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||||
|
@ -521,6 +523,90 @@ class Vits(BaseTTS):
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def forward_fine_tuning(
|
||||||
|
self,
|
||||||
|
x: torch.tensor,
|
||||||
|
x_lengths: torch.tensor,
|
||||||
|
y: torch.tensor,
|
||||||
|
y_lengths: torch.tensor,
|
||||||
|
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
|
||||||
|
) -> Dict:
|
||||||
|
"""Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.tensor): Batch of input character sequence IDs.
|
||||||
|
x_lengths (torch.tensor): Batch of input character sequence lengths.
|
||||||
|
y (torch.tensor): Batch of input spectrograms.
|
||||||
|
y_lengths (torch.tensor): Batch of input spectrogram lengths.
|
||||||
|
aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: model outputs keyed by the output name.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_seq]`
|
||||||
|
- x_lengths: :math:`[B]`
|
||||||
|
- y: :math:`[B, C, T_spec]`
|
||||||
|
- y_lengths: :math:`[B]`
|
||||||
|
- d_vectors: :math:`[B, C, 1]`
|
||||||
|
- speaker_ids: :math:`[B]`
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = {}
|
||||||
|
sid, g, lid = self._set_cond_input(aux_input)
|
||||||
|
# speaker embedding
|
||||||
|
if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
|
||||||
|
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
|
# language embedding
|
||||||
|
lang_emb=None
|
||||||
|
if self.args.use_language_embedding and lid is not None:
|
||||||
|
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
||||||
|
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||||
|
|
||||||
|
# posterior encoder
|
||||||
|
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||||
|
|
||||||
|
# flow layers
|
||||||
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
|
# find the alignment path
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
with torch.no_grad():
|
||||||
|
o_scale = torch.exp(-2 * logs_p)
|
||||||
|
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
|
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
||||||
|
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||||
|
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
|
logp = logp2 + logp3
|
||||||
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
|
|
||||||
|
# expand prior
|
||||||
|
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
|
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||||
|
|
||||||
|
# get the z after inverse decoder
|
||||||
|
# ToDo: test if using m_p the result is better (In the SC-GlowTTS paper we used mp instead z_p)
|
||||||
|
z_f_pred = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||||
|
z_slice, slice_ids = rand_segment(z_f_pred, y_lengths, self.spec_segment_size)
|
||||||
|
|
||||||
|
o = self.waveform_decoder(z_slice, g=g)
|
||||||
|
outputs.update(
|
||||||
|
{
|
||||||
|
"model_outputs": o,
|
||||||
|
"alignments": attn.squeeze(1),
|
||||||
|
"slice_ids": slice_ids,
|
||||||
|
"z": z,
|
||||||
|
"z_p": z_p,
|
||||||
|
"m_p": m_p,
|
||||||
|
"logs_p": logs_p,
|
||||||
|
"m_q": m_q,
|
||||||
|
"logs_q": logs_q,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -599,6 +685,15 @@ class Vits(BaseTTS):
|
||||||
if optimizer_idx not in [0, 1]:
|
if optimizer_idx not in [0, 1]:
|
||||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||||
|
|
||||||
|
# generator pass
|
||||||
|
if self.args.fine_tuning_mode:
|
||||||
|
# ToDo: find better place fot it
|
||||||
|
# force eval mode
|
||||||
|
self.eval()
|
||||||
|
# restore train mode for the vocoder part
|
||||||
|
self.waveform_decoder.train()
|
||||||
|
self.disc.train()
|
||||||
|
|
||||||
if optimizer_idx == 0:
|
if optimizer_idx == 0:
|
||||||
text_input = batch["text_input"]
|
text_input = batch["text_input"]
|
||||||
text_lengths = batch["text_lengths"]
|
text_lengths = batch["text_lengths"]
|
||||||
|
@ -610,6 +705,17 @@ class Vits(BaseTTS):
|
||||||
waveform = batch["waveform"]
|
waveform = batch["waveform"]
|
||||||
|
|
||||||
# generator pass
|
# generator pass
|
||||||
|
if self.args.fine_tuning_mode:
|
||||||
|
|
||||||
|
# model forward
|
||||||
|
outputs = self.forward_fine_tuning(
|
||||||
|
text_input,
|
||||||
|
text_lengths,
|
||||||
|
linear_input.transpose(1, 2),
|
||||||
|
mel_lengths,
|
||||||
|
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
|
||||||
|
)
|
||||||
|
else:
|
||||||
outputs = self.forward(
|
outputs = self.forward(
|
||||||
text_input,
|
text_input,
|
||||||
text_lengths,
|
text_lengths,
|
||||||
|
@ -649,8 +755,10 @@ class Vits(BaseTTS):
|
||||||
feats_disc_fake=outputs["feats_disc_fake"],
|
feats_disc_fake=outputs["feats_disc_fake"],
|
||||||
feats_disc_real=outputs["feats_disc_real"],
|
feats_disc_real=outputs["feats_disc_real"],
|
||||||
loss_duration=outputs["loss_duration"],
|
loss_duration=outputs["loss_duration"],
|
||||||
|
fine_tuning_mode=self.args.fine_tuning_mode,
|
||||||
)
|
)
|
||||||
|
# ignore duration loss if fine tuning mode is on
|
||||||
|
if not self.args.fine_tuning_mode:
|
||||||
# handle the duration loss
|
# handle the duration loss
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
loss_dict["nll_duration"] = outputs["nll_duration"]
|
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||||
|
@ -853,3 +961,5 @@ class Vits(BaseTTS):
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue