mirror of https://github.com/coqui-ai/TTS.git
Refactor VITS model
parent ef63c99524
commit 2829027d8b

@@ -38,7 +38,7 @@ class VitsArgs(Coqpit):
             Number of characters in the vocabulary. Defaults to 100.
 
         out_channels (int):
-            Number of output channels. Defaults to 513.
+            Number of output channels of the decoder. Defaults to 513.
 
         spec_segment_size (int):
             Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
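
For orientation on those defaults: 513 output channels corresponds to the linear-spectrogram bins of a 1024-point FFT (n_fft / 2 + 1), and a segment of 32 decoder frames maps to 32 * hop_length waveform samples. A quick sketch with assumed values (fft_size and hop_length below are illustrative, not read from this diff):

fft_size = 1024    # assumed STFT size; the real value comes from the audio config
hop_length = 256   # assumed hop length
spec_segment_size = 32

out_channels = fft_size // 2 + 1                    # 513 linear-spectrogram bins
waveform_segment = spec_segment_size * hop_length   # 8192 samples per training segment
print(out_channels, waveform_segment)               # 513 8192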

@@ -363,6 +363,8 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )
 
+        upsample_rate = math.prod(self.args.upsample_rates_decoder)
+        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
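
The new assertion ties the HiFi-GAN decoder to the spectrogram frontend: the decoder upsamples each input frame by the product of its upsample rates, and that product must equal the STFT hop length so a frame-level slice and its waveform slice stay aligned. A minimal sketch with assumed rates (the actual values come from VitsArgs and the audio config):

import math

upsample_rates_decoder = [8, 8, 2, 2]  # assumed HiFi-GAN upsampling factors
hop_length = 256                       # assumed STFT hop length

# One decoder input frame must expand to exactly one hop of waveform samples.
upsample_rate = math.prod(upsample_rates_decoder)
assert upsample_rate == hop_length, f"{upsample_rate} vs {hop_length}"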

@@ -531,6 +533,54 @@ class Vits(BaseTTS):
             "language_name": language_name,
         }
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
+    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
+        # find the alignment path
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * logs_p)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
+            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3 + logp1 + logp4
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
+
+        # duration predictor
+        attn_durations = attn.sum(3)
+        if self.args.use_sdp:
+            loss_duration = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                attn_durations,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = loss_duration / torch.sum(x_mask)
+        else:
+            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            log_durations = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
+        outputs["loss_duration"] = loss_duration
+        return outputs, attn
+
     def forward(
         self,
         x: torch.tensor,
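
A note on the logp terms in the new forward_mas (this is a reading of the code above, not part of the diff): with o_scale = exp(-2 * logs_p) = 1 / sigma^2, the four terms expand the diagonal-Gaussian log-density -0.5*log(2*pi) - log(sigma) - (z - mu)^2 / (2*sigma^2), summed over channels, for every (text frame, spectrogram frame) pair; maximum_path then searches for the monotonic alignment that maximizes the total. A small self-contained check of that identity with toy tensors and standard PyTorch:

import math
import torch

b, c, t_text, t_spec = 1, 2, 3, 4
m_p = torch.randn(b, c, t_text)      # prior means per text frame
logs_p = torch.randn(b, c, t_text)   # prior log-std per text frame
z_p = torch.randn(b, c, t_spec)      # flow output per spectrogram frame

o_scale = torch.exp(-2 * logs_p)     # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp2 + logp3 + logp1 + logp4          # [b, t_text, t_spec]

# Direct evaluation of the same log-density for one (text, spec) pair.
i, j = 1, 2
ref = torch.distributions.Normal(m_p[0, :, i], torch.exp(logs_p[0, :, i])).log_prob(z_p[0, :, j]).sum()
assert torch.allclose(logp[0, i, j], ref, atol=1e-5)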

@@ -596,54 +646,27 @@ class Vits(BaseTTS):
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
 
-        # find the alignment path
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        with torch.no_grad():
-            o_scale = torch.exp(-2 * logs_p)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
-            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp = logp2 + logp3 + logp1 + logp4
-            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-
-        # duration predictor
-        attn_durations = attn.sum(3)
-        g_dp = None
-        if self.args.condition_dp_on_speaker:
-            g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
-        if self.args.use_sdp:
-            loss_duration = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                attn_durations,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = loss_duration / torch.sum(x_mask)
-        else:
-            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
-            log_durations = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-        outputs["loss_duration"] = loss_duration
+        if self.args.use_mas:
+            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+        elif self.args.use_aligner_network:
+            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
+            outputs["x_lens"] = x_lengths
+            outputs["y_lens"] = y_lengths
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
+            pad_short=True
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
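
The segment slicing above relies on the frame-to-sample correspondence asserted earlier in the diff: a slice starting at spectrogram frame s and spanning spec_segment_size frames matches the waveform slice starting at s * hop_length and spanning spec_segment_size * hop_length samples. A toy sketch of that bookkeeping (made-up shapes; rand_segments and segment are only mimicked here, not imported from the repo):

import torch

hop_length = 256          # assumed hop length
spec_segment_size = 32
z = torch.randn(2, 192, 120)                     # [batch, channels, spec frames]
waveform = torch.randn(2, 1, 120 * hop_length)   # matching audio

# Pick a random start frame per item, as rand_segments() does.
slice_ids = torch.randint(0, z.size(2) - spec_segment_size + 1, (z.size(0),)).tolist()

z_slice = torch.stack([z[i, :, s : s + spec_segment_size] for i, s in enumerate(slice_ids)])
wav_seg = torch.stack(
    [waveform[i, :, s * hop_length : (s + spec_segment_size) * hop_length] for i, s in enumerate(slice_ids)]
)
# The decoder upsamples each frame by hop_length, so the two segments line up.
assert wav_seg.size(-1) == z_slice.size(-1) * hop_length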

@@ -665,11 +688,11 @@ class Vits(BaseTTS):
         outputs.update(
             {
                 "model_outputs": o,
-                "alignments": attn.squeeze(1),
-                "z": z,
-                "z_p": z_p,
+                "alignments" : attn.squeeze(1),
                 "m_p": m_p,
                 "logs_p": logs_p,
+                "z": z,
+                "z_p": z_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,

@@ -919,14 +942,18 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        self._log(self.ap, batch, outputs, "train")
+        figures, audios = self._log(self.ap, batch, outputs, "train")
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        return self._log(self.ap, batch, outputs, "eval")
+        figures, audios = self._log(self.ap, batch, outputs, "eval")
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:
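
The logging change assumes the `_log` helper returns a (figures, audios) pair rather than writing to the logger itself; train_log and eval_log then hand both dicts to the experiment logger along with the sample rate. A hedged illustration of that contract (the helper body below is invented for illustration; only the logger calls mirror the diff):

from typing import Dict, Tuple

def _log(ap, batch, outputs, name: str) -> Tuple[Dict, Dict]:
    # Stand-in body: the real helper builds alignment/spectrogram figures and
    # picks an example waveform from `outputs`.
    figures = {f"{name}/alignment": outputs["alignments"]}
    audios = {f"{name}/audio": outputs["model_outputs"][0]}
    return figures, audios

# Inside the model, per the diff:
#     figures, audios = self._log(self.ap, batch, outputs, "eval")
#     logger.eval_figures(steps, figures)
#     logger.eval_audios(steps, audios, self.ap.sample_rate)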