mirror of https://github.com/coqui-ai/TTS.git
Fix MAS
This commit is contained in:
parent
76b274e690
commit
b78431d7e2
|
@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
@staticmethod
|
||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||
"""
|
||||
z_p, logs_q: [b, h, t_t]
|
||||
m_p, logs_p: [b, h, t_t]
|
||||
z_p, logs_q: [B, C, T_de]
|
||||
m_p, logs_p: [B, C, T_de]
|
||||
"""
|
||||
z_p = z_p.float()
|
||||
logs_q = logs_q.float()
|
||||
|
|
|
@ -735,6 +735,7 @@ class Vits(BaseTTS):
|
|||
"""
|
||||
if self.args.encoder_sample_rate:
|
||||
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
||||
assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer."
|
||||
self.audio_resampler = torchaudio.transforms.Resample(
|
||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||
) # pylint: disable=W0201
|
||||
|
@ -803,18 +804,18 @@ class Vits(BaseTTS):
|
|||
|
||||
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
|
||||
# find the alignment path
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * logs_p)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
|
||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [B, 1, T_en]
|
||||
logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale]) # [B, T_de, T_en]
|
||||
logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale]) # [B, T_de, T_en]
|
||||
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [B, 1, T_en]
|
||||
logp = logp2 + logp3 + logp1 + logp4
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
|
||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en]
|
||||
|
||||
# duration predictor
|
||||
attn_durations = attn.sum(3)
|
||||
attn_durations = attn.sum(2)
|
||||
if self.args.use_sdp:
|
||||
loss_duration = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
|
@ -845,7 +846,7 @@ class Vits(BaseTTS):
|
|||
# interpolate z if needed
|
||||
if self.args.interpolate_z:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
|
||||
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear"
|
||||
).squeeze(0)
|
||||
# recompute the mask if needed
|
||||
if y_lengths is not None and y_mask is not None:
|
||||
|
@ -890,7 +891,7 @@ class Vits(BaseTTS):
|
|||
|
||||
Return Shapes:
|
||||
- model_outputs: :math:`[B, 1, T_wav]`
|
||||
- alignments: :math:`[B, T_seq, T_dec]`
|
||||
- alignments: :math:`[B, T_dec, T_seq]`
|
||||
- z: :math:`[B, C, T_dec]`
|
||||
- z_p: :math:`[B, C, T_dec]`
|
||||
- m_p: :math:`[B, C, T_dec]`
|
||||
|
@ -924,8 +925,8 @@ class Vits(BaseTTS):
|
|||
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
||||
|
||||
# expand prior
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
||||
m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p]) # [B, 1, T_de, T_en] -> [B, C, T_de]
|
||||
logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de]
|
||||
|
||||
# select a random feature segment for the waveform decoder
|
||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
||||
|
@ -982,6 +983,7 @@ class Vits(BaseTTS):
|
|||
return aux_input["x_lengths"]
|
||||
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
|
@ -1053,6 +1055,7 @@ class Vits(BaseTTS):
|
|||
outputs = {
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"durations": w_ceil,
|
||||
"z": z,
|
||||
"z_p": z_p,
|
||||
"m_p": m_p,
|
||||
|
@ -1231,7 +1234,7 @@ class Vits(BaseTTS):
|
|||
audios = {f"{name_prefix}/audio": sample_voice}
|
||||
|
||||
alignments = outputs[1]["alignments"]
|
||||
align_img = alignments[0].data.cpu().numpy().T
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures.update(
|
||||
{
|
||||
|
@ -1390,16 +1393,16 @@ class Vits(BaseTTS):
|
|||
"""Compute spectrograms on the device."""
|
||||
ac = self.config.audio
|
||||
|
||||
if self.args.encoder_sample_rate:
|
||||
wav = self.audio_resampler(batch["waveform"])
|
||||
else:
|
||||
wav = batch["waveform"]
|
||||
|
||||
# compute spectrograms
|
||||
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||
if self.args.encoder_sample_rate:
|
||||
# downsample audio to encoder sample rate
|
||||
wav = self.audio_resampler(batch["waveform"])
|
||||
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||
else:
|
||||
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||
|
||||
if self.args.encoder_sample_rate:
|
||||
# recompute spec with high sampling rate to the loss
|
||||
# recompute spec with vocoder sampling rate
|
||||
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||
# remove extra stft frame
|
||||
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
||||
|
|
|
@ -180,10 +180,10 @@ def maximum_path(value, mask):
|
|||
def maximum_path_cython(value, mask):
|
||||
"""Cython optimised version.
|
||||
Shapes:
|
||||
- value: :math:`[B, T_en, T_de]`
|
||||
- mask: :math:`[B, T_en, T_de]`
|
||||
- value: :math:`[B, T_de, T_en]`
|
||||
- mask: :math:`[B, T_de, T_en]`
|
||||
"""
|
||||
value = value * mask
|
||||
# value = value * mask
|
||||
device = value.device
|
||||
dtype = value.dtype
|
||||
value = value.data.cpu().numpy().astype(np.float32)
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
import numpy as np
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
|
||||
import cython
|
||||
from cython.parallel import prange
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
|
||||
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
|
||||
cdef int x
|
||||
cdef int y
|
||||
cdef float v_prev
|
||||
|
@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_
|
|||
if x == y:
|
||||
v_cur = max_neg_val
|
||||
else:
|
||||
v_cur = value[x, y-1]
|
||||
v_cur = value[y-1, x]
|
||||
if x == 0:
|
||||
if y == 0:
|
||||
v_prev = 0.
|
||||
else:
|
||||
v_prev = max_neg_val
|
||||
else:
|
||||
v_prev = value[x-1, y-1]
|
||||
value[x, y] = max(v_cur, v_prev) + value[x, y]
|
||||
v_prev = value[y-1, x-1]
|
||||
value[y, x] += max(v_prev, v_cur)
|
||||
|
||||
for y in range(t_y - 1, -1, -1):
|
||||
path[index, y] = 1
|
||||
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
|
||||
path[y, index] = 1
|
||||
if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
|
||||
index = index - 1
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
|
||||
cdef int b = values.shape[0]
|
||||
|
||||
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
|
||||
cdef int b = paths.shape[0]
|
||||
cdef int i
|
||||
for i in prange(b, nogil=True):
|
||||
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
|
||||
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
|
||||
|
|
Loading…
Reference in New Issue