Refactor VITS model

Eren Gölge 2022-01-21 15:33:15 +00:00
parent ef63c99524
commit 2829027d8b
1 changed file with 68 additions and 41 deletions


@@ -38,7 +38,7 @@ class VitsArgs(Coqpit):
             Number of characters in the vocabulary. Defaults to 100.
 
         out_channels (int):
-            Number of output channels. Defaults to 513.
+            Number of output channels of the decoder. Defaults to 513.
 
         spec_segment_size (int):
             Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
@@ -363,6 +363,8 @@ class Vits(BaseTTS):
             language_emb_dim=self.embedded_language_dim,
         )
 
+        upsample_rate = math.prod(self.args.upsample_rates_decoder)
+        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
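
Note on the new assert: the HiFi-GAN decoder upsamples each spectrogram frame by the product of its upsample rates, so that product must equal the STFT hop length or the generated waveform length will not match the target audio. A minimal sketch of the constraint, with illustrative values in the common VITS range (not necessarily this repo's defaults):

    import math

    # Illustrative values; a common VITS setup uses four upsampling
    # stages whose product equals the 256-sample hop length.
    upsample_rates_decoder = [8, 8, 2, 2]
    hop_length = 256

    # Each decoder stage upsamples by one factor, so the product is the
    # number of waveform samples produced per spectrogram frame.
    assert math.prod(upsample_rates_decoder) == hop_length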
@@ -531,6 +533,54 @@ class Vits(BaseTTS):
             "language_name": language_name,
         }
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
+    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
+        # find the alignment path
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * logs_p)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
+            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3 + logp1 + logp4
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
+
+        # duration predictor
+        attn_durations = attn.sum(3)
+        if self.args.use_sdp:
+            loss_duration = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                attn_durations,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = loss_duration / torch.sum(x_mask)
+        else:
+            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            log_durations = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
+        outputs["loss_duration"] = loss_duration
+        return outputs, attn
+
     def forward(
         self,
         x: torch.tensor,
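
The four logp terms in the new forward_mas are the diagonal-Gaussian log-likelihood log N(z_p; m_p, exp(logs_p)) expanded so that the score of every (text token, spectrogram frame) pair falls out of two einsums plus two broadcast token-only terms, instead of materializing an explicit per-pair difference tensor inside the likelihood. A self-contained sketch that checks the decomposition against the direct computation (shapes assumed from the comments in the diff; not repo code):

    import math
    import torch

    # Assumed shapes: prior stats [batch, channels, t_text],
    # flow output z_p [batch, channels, t_spec].
    b, c, t_x, t_y = 2, 4, 5, 7
    m_p, logs_p = torch.randn(b, c, t_x), torch.randn(b, c, t_x)
    z_p = torch.randn(b, c, t_y)

    # Decomposed form, mirroring forward_mas.
    o_scale = torch.exp(-2 * logs_p)  # 1 / sigma^2
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)
    logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
    logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
    logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)
    logp = logp1 + logp2 + logp3 + logp4  # [b, t_text, t_spec]

    # Direct evaluation of log N(z_p; m_p, exp(logs_p)) per (token, frame) pair.
    diff = z_p.unsqueeze(2) - m_p.unsqueeze(-1)  # [b, c, t_text, t_spec]
    direct = torch.sum(
        -0.5 * math.log(2 * math.pi)
        - logs_p.unsqueeze(-1)
        - 0.5 * diff ** 2 * o_scale.unsqueeze(-1),
        dim=1,
    )
    assert torch.allclose(logp, direct, atol=1e-5)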
@@ -596,54 +646,27 @@ class Vits(BaseTTS):
 
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
 
-        # find the alignment path
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        with torch.no_grad():
-            o_scale = torch.exp(-2 * logs_p)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
-            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp = logp2 + logp3 + logp1 + logp4
-            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-        # duration predictor
-        attn_durations = attn.sum(3)
-        g_dp = None
-        if self.args.condition_dp_on_speaker:
-            g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
-        if self.args.use_sdp:
-            loss_duration = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                attn_durations,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = loss_duration / torch.sum(x_mask)
-        else:
-            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
-            log_durations = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-        outputs["loss_duration"] = loss_duration
+        if self.args.use_mas:
+            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+        elif self.args.use_aligner_network:
+            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
+        outputs["x_lens"] = x_lengths
+        outputs["y_lens"] = y_lengths
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
+            pad_short=True
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
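
The hop-length products in the segment call exist because rand_segments picks its slice in spectrogram frames while segment must cut the matching span from the raw waveform; the new let_short_samples and pad_short flags let clips shorter than the segment size be padded rather than rejected. The frame-to-sample arithmetic, with made-up numbers:

    hop_length = 256        # waveform samples per spectrogram frame
    spec_segment_size = 32  # decoder segment length, in frames
    slice_id = 40           # hypothetical start frame from rand_segments

    wav_start = slice_id * hop_length            # 10240: first waveform sample of the segment
    wav_length = spec_segment_size * hop_length  # 8192: samples covering the 32 frames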
@@ -665,11 +688,11 @@ class Vits(BaseTTS):
         outputs.update(
             {
                 "model_outputs": o,
-                "alignments": attn.squeeze(1),
-                "z": z,
-                "z_p": z_p,
+                "alignments" : attn.squeeze(1),
                 "m_p": m_p,
                 "logs_p": logs_p,
+                "z": z,
+                "z_p": z_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
@@ -919,14 +942,18 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        self._log(self.ap, batch, outputs, "train")
+        figures, audios = self._log(self.ap, batch, outputs, "train")
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        return self._log(self.ap, batch, outputs, "eval")
+        figures, audios = self._log(self.ap, batch, outputs, "eval")
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:
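
One behavioral note on the last hunk: _log now returns a (figures, audios) pair, and train_log/eval_log push both to the logger themselves instead of returning them to the trainer (both paths call the eval_* logger methods in this commit). A runnable stub of the logger interface these calls assume (method names taken from the diff, bodies hypothetical):

    from typing import Dict

    # Hypothetical stand-in for the trainer's logger; only the two
    # methods used in the diff are stubbed out.
    class StubLogger:
        def eval_figures(self, steps: int, figures: Dict) -> None:
            print(f"step {steps}: {len(figures)} figures")

        def eval_audios(self, steps: int, audios: Dict, sample_rate: int) -> None:
            print(f"step {steps}: {len(audios)} audios at {sample_rate} Hz")

    logger = StubLogger()
    figures, audios = {"alignment": None}, {"audio": None}  # placeholder _log output
    logger.eval_figures(100, figures)
    logger.eval_audios(100, audios, 22050)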