mirror of https://github.com/coqui-ai/TTS.git
Fix MAS
This commit is contained in:
parent
76b274e690
commit
b78431d7e2
|
@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||||
"""
|
"""
|
||||||
z_p, logs_q: [b, h, t_t]
|
z_p, logs_q: [B, C, T_de]
|
||||||
m_p, logs_p: [b, h, t_t]
|
m_p, logs_p: [B, C, T_de]
|
||||||
"""
|
"""
|
||||||
z_p = z_p.float()
|
z_p = z_p.float()
|
||||||
logs_q = logs_q.float()
|
logs_q = logs_q.float()
|
||||||
|
|
|
@ -735,6 +735,7 @@ class Vits(BaseTTS):
|
||||||
"""
|
"""
|
||||||
if self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
||||||
|
assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer."
|
||||||
self.audio_resampler = torchaudio.transforms.Resample(
|
self.audio_resampler = torchaudio.transforms.Resample(
|
||||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||||
) # pylint: disable=W0201
|
) # pylint: disable=W0201
|
||||||
|
@ -803,18 +804,18 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
|
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
|
||||||
# find the alignment path
|
# find the alignment path
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
o_scale = torch.exp(-2 * logs_p)
|
o_scale = torch.exp(-2 * logs_p)
|
||||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [B, 1, T_en]
|
||||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
|
logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale]) # [B, T_de, T_en]
|
||||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale]) # [B, T_de, T_en]
|
||||||
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [B, 1, T_en]
|
||||||
logp = logp2 + logp3 + logp1 + logp4
|
logp = logp2 + logp3 + logp1 + logp4
|
||||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en]
|
||||||
|
|
||||||
# duration predictor
|
# duration predictor
|
||||||
attn_durations = attn.sum(3)
|
attn_durations = attn.sum(2)
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
loss_duration = self.duration_predictor(
|
loss_duration = self.duration_predictor(
|
||||||
x.detach() if self.args.detach_dp_input else x,
|
x.detach() if self.args.detach_dp_input else x,
|
||||||
|
@ -845,7 +846,7 @@ class Vits(BaseTTS):
|
||||||
# interpolate z if needed
|
# interpolate z if needed
|
||||||
if self.args.interpolate_z:
|
if self.args.interpolate_z:
|
||||||
z = torch.nn.functional.interpolate(
|
z = torch.nn.functional.interpolate(
|
||||||
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
|
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear"
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
# recompute the mask if needed
|
# recompute the mask if needed
|
||||||
if y_lengths is not None and y_mask is not None:
|
if y_lengths is not None and y_mask is not None:
|
||||||
|
@ -890,7 +891,7 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
Return Shapes:
|
Return Shapes:
|
||||||
- model_outputs: :math:`[B, 1, T_wav]`
|
- model_outputs: :math:`[B, 1, T_wav]`
|
||||||
- alignments: :math:`[B, T_seq, T_dec]`
|
- alignments: :math:`[B, T_dec, T_seq]`
|
||||||
- z: :math:`[B, C, T_dec]`
|
- z: :math:`[B, C, T_dec]`
|
||||||
- z_p: :math:`[B, C, T_dec]`
|
- z_p: :math:`[B, C, T_dec]`
|
||||||
- m_p: :math:`[B, C, T_dec]`
|
- m_p: :math:`[B, C, T_dec]`
|
||||||
|
@ -924,8 +925,8 @@ class Vits(BaseTTS):
|
||||||
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
||||||
|
|
||||||
# expand prior
|
# expand prior
|
||||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p]) # [B, 1, T_de, T_en] -> [B, C, T_de]
|
||||||
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
|
logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de]
|
||||||
|
|
||||||
# select a random feature segment for the waveform decoder
|
# select a random feature segment for the waveform decoder
|
||||||
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
|
||||||
|
@ -982,6 +983,7 @@ class Vits(BaseTTS):
|
||||||
return aux_input["x_lengths"]
|
return aux_input["x_lengths"]
|
||||||
return torch.tensor(x.shape[1:2]).to(x.device)
|
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(
|
def inference(
|
||||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
|
@ -1053,6 +1055,7 @@ class Vits(BaseTTS):
|
||||||
outputs = {
|
outputs = {
|
||||||
"model_outputs": o,
|
"model_outputs": o,
|
||||||
"alignments": attn.squeeze(1),
|
"alignments": attn.squeeze(1),
|
||||||
|
"durations": w_ceil,
|
||||||
"z": z,
|
"z": z,
|
||||||
"z_p": z_p,
|
"z_p": z_p,
|
||||||
"m_p": m_p,
|
"m_p": m_p,
|
||||||
|
@ -1231,7 +1234,7 @@ class Vits(BaseTTS):
|
||||||
audios = {f"{name_prefix}/audio": sample_voice}
|
audios = {f"{name_prefix}/audio": sample_voice}
|
||||||
|
|
||||||
alignments = outputs[1]["alignments"]
|
alignments = outputs[1]["alignments"]
|
||||||
align_img = alignments[0].data.cpu().numpy().T
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
figures.update(
|
figures.update(
|
||||||
{
|
{
|
||||||
|
@ -1390,16 +1393,16 @@ class Vits(BaseTTS):
|
||||||
"""Compute spectrograms on the device."""
|
"""Compute spectrograms on the device."""
|
||||||
ac = self.config.audio
|
ac = self.config.audio
|
||||||
|
|
||||||
if self.args.encoder_sample_rate:
|
|
||||||
wav = self.audio_resampler(batch["waveform"])
|
|
||||||
else:
|
|
||||||
wav = batch["waveform"]
|
|
||||||
|
|
||||||
# compute spectrograms
|
# compute spectrograms
|
||||||
|
if self.args.encoder_sample_rate:
|
||||||
|
# downsample audio to encoder sample rate
|
||||||
|
wav = self.audio_resampler(batch["waveform"])
|
||||||
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||||
|
else:
|
||||||
|
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||||
|
|
||||||
if self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
# recompute spec with high sampling rate to the loss
|
# recompute spec with vocoder sampling rate
|
||||||
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||||
# remove extra stft frame
|
# remove extra stft frame
|
||||||
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
||||||
|
|
|
@ -180,10 +180,10 @@ def maximum_path(value, mask):
|
||||||
def maximum_path_cython(value, mask):
|
def maximum_path_cython(value, mask):
|
||||||
"""Cython optimised version.
|
"""Cython optimised version.
|
||||||
Shapes:
|
Shapes:
|
||||||
- value: :math:`[B, T_en, T_de]`
|
- value: :math:`[B, T_de, T_en]`
|
||||||
- mask: :math:`[B, T_en, T_de]`
|
- mask: :math:`[B, T_de, T_en]`
|
||||||
"""
|
"""
|
||||||
value = value * mask
|
# value = value * mask
|
||||||
device = value.device
|
device = value.device
|
||||||
dtype = value.dtype
|
dtype = value.dtype
|
||||||
value = value.data.cpu().numpy().astype(np.float32)
|
value = value.data.cpu().numpy().astype(np.float32)
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
import numpy as np
|
import cython
|
||||||
|
|
||||||
cimport cython
|
|
||||||
cimport numpy as np
|
|
||||||
|
|
||||||
from cython.parallel import prange
|
from cython.parallel import prange
|
||||||
|
|
||||||
|
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
|
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
|
||||||
cdef int x
|
cdef int x
|
||||||
cdef int y
|
cdef int y
|
||||||
cdef float v_prev
|
cdef float v_prev
|
||||||
|
@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_
|
||||||
if x == y:
|
if x == y:
|
||||||
v_cur = max_neg_val
|
v_cur = max_neg_val
|
||||||
else:
|
else:
|
||||||
v_cur = value[x, y-1]
|
v_cur = value[y-1, x]
|
||||||
if x == 0:
|
if x == 0:
|
||||||
if y == 0:
|
if y == 0:
|
||||||
v_prev = 0.
|
v_prev = 0.
|
||||||
else:
|
else:
|
||||||
v_prev = max_neg_val
|
v_prev = max_neg_val
|
||||||
else:
|
else:
|
||||||
v_prev = value[x-1, y-1]
|
v_prev = value[y-1, x-1]
|
||||||
value[x, y] = max(v_cur, v_prev) + value[x, y]
|
value[y, x] += max(v_prev, v_cur)
|
||||||
|
|
||||||
for y in range(t_y - 1, -1, -1):
|
for y in range(t_y - 1, -1, -1):
|
||||||
path[index, y] = 1
|
path[y, index] = 1
|
||||||
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
|
if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
|
||||||
index = index - 1
|
index = index - 1
|
||||||
|
|
||||||
|
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
|
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
|
||||||
cdef int b = values.shape[0]
|
cdef int b = paths.shape[0]
|
||||||
|
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in prange(b, nogil=True):
|
for i in prange(b, nogil=True):
|
||||||
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
|
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
|
||||||
|
|
Loading…
Reference in New Issue