From b78431d7e270f95e477ee1d39b3485cba96b6b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 3 May 2022 00:34:55 +0200 Subject: [PATCH] Fix MAS --- TTS/tts/layers/losses.py | 4 +-- TTS/tts/models/vits.py | 41 ++++++++++++++------------ TTS/tts/utils/helpers.py | 6 ++-- TTS/tts/utils/monotonic_align/core.pyx | 25 +++++++--------- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e03cf084..fca00f0b 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module): @staticmethod def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): """ - z_p, logs_q: [b, h, t_t] - m_p, logs_p: [b, h, t_t] + z_p, logs_q: [B, C, T_de] + m_p, logs_p: [B, C, T_de] """ z_p = z_p.float() logs_q = logs_q.float() diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 613e4eae..55c39c3b 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -735,6 +735,7 @@ class Vits(BaseTTS): """ if self.args.encoder_sample_rate: self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate + assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer." self.audio_resampler = torchaudio.transforms.Resample( orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate ) # pylint: disable=W0201 @@ -803,18 +804,18 @@ class Vits(BaseTTS): def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): # find the alignment path - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) with torch.no_grad(): o_scale = torch.exp(-2 * logs_p) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) - logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [B, 1, T_en] + logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale]) # [B, T_de, T_en] + logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale]) # [B, T_de, T_en] + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [B, 1, T_en] logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en] # duration predictor - attn_durations = attn.sum(3) + attn_durations = attn.sum(2) if self.args.use_sdp: loss_duration = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, @@ -845,7 +846,7 @@ class Vits(BaseTTS): # interpolate z if needed if self.args.interpolate_z: z = torch.nn.functional.interpolate( - z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest" + z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear" ).squeeze(0) # recompute the mask if needed if y_lengths is not None and y_mask is not None: @@ -890,7 +891,7 @@ class Vits(BaseTTS): Return Shapes: - model_outputs: :math:`[B, 1, T_wav]` - - alignments: :math:`[B, T_seq, T_dec]` + - alignments: :math:`[B, T_dec, T_seq]` - z: :math:`[B, C, T_dec]` - z_p: :math:`[B, C, T_dec]` - m_p: :math:`[B, C, T_dec]` @@ -924,8 +925,8 @@ class Vits(BaseTTS): outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) # expand prior - m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) - logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p]) # [B, 1, T_de, T_en] -> [B, C, T_de] + logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de] # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) @@ -982,6 +983,7 @@ class Vits(BaseTTS): return aux_input["x_lengths"] return torch.tensor(x.shape[1:2]).to(x.device) + @torch.no_grad() def inference( self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} ): # pylint: disable=dangerous-default-value @@ -1053,6 +1055,7 @@ class Vits(BaseTTS): outputs = { "model_outputs": o, "alignments": attn.squeeze(1), + "durations": w_ceil, "z": z, "z_p": z_p, "m_p": m_p, @@ -1231,7 +1234,7 @@ class Vits(BaseTTS): audios = {f"{name_prefix}/audio": sample_voice} alignments = outputs[1]["alignments"] - align_img = alignments[0].data.cpu().numpy().T + align_img = alignments[0].data.cpu().numpy() figures.update( { @@ -1390,16 +1393,16 @@ class Vits(BaseTTS): """Compute spectrograms on the device.""" ac = self.config.audio - if self.args.encoder_sample_rate: - wav = self.audio_resampler(batch["waveform"]) - else: - wav = batch["waveform"] - # compute spectrograms - batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + if self.args.encoder_sample_rate: + # downsample audio to encoder sample rate + wav = self.audio_resampler(batch["waveform"]) + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + else: + batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) if self.args.encoder_sample_rate: - # recompute spec with high sampling rate to the loss + # recompute spec with vocoder sampling rate spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) # remove extra stft frame spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index c2e7f561..3c8bc2c8 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -180,10 +180,10 @@ def maximum_path(value, mask): def maximum_path_cython(value, mask): """Cython optimised version. Shapes: - - value: :math:`[B, T_en, T_de]` - - mask: :math:`[B, T_en, T_de]` + - value: :math:`[B, T_de, T_en]` + - mask: :math:`[B, T_de, T_en]` """ - value = value * mask + # value = value * mask device = value.device dtype = value.dtype value = value.data.cpu().numpy().astype(np.float32) diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx index 091fcc3a..227fc1ed 100644 --- a/TTS/tts/utils/monotonic_align/core.pyx +++ b/TTS/tts/utils/monotonic_align/core.pyx @@ -1,14 +1,10 @@ -import numpy as np - -cimport cython -cimport numpy as np - +import cython from cython.parallel import prange @cython.boundscheck(False) @cython.wraparound(False) -cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: cdef int x cdef int y cdef float v_prev @@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_ if x == y: v_cur = max_neg_val else: - v_cur = value[x, y-1] + v_cur = value[y-1, x] if x == 0: if y == 0: v_prev = 0. else: v_prev = max_neg_val else: - v_prev = value[x-1, y-1] - value[x, y] = max(v_cur, v_prev) + value[x, y] + v_prev = value[y-1, x-1] + value[y, x] += max(v_prev, v_cur) for y in range(t_y - 1, -1, -1): - path[index, y] = 1 - if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + path[y, index] = 1 + if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): index = index - 1 @cython.boundscheck(False) @cython.wraparound(False) -cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: - cdef int b = values.shape[0] - +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] cdef int i for i in prange(b, nogil=True): - maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])