This commit is contained in:
Eren Gölge 2022-05-03 00:34:55 +02:00
parent 76b274e690
commit b78431d7e2
4 changed files with 37 additions and 39 deletions

View File

@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module):
@staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
z_p, logs_q: [B, C, T_de]
m_p, logs_p: [B, C, T_de]
"""
z_p = z_p.float()
logs_q = logs_q.float()

View File

@ -735,6 +735,7 @@ class Vits(BaseTTS):
"""
if self.args.encoder_sample_rate:
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer."
self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
) # pylint: disable=W0201
@ -803,18 +804,18 @@ class Vits(BaseTTS):
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
# find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [B, 1, T_en]
logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale]) # [B, T_de, T_en]
logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale]) # [B, T_de, T_en]
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [B, 1, T_en]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en]
# duration predictor
attn_durations = attn.sum(3)
attn_durations = attn.sum(2)
if self.args.use_sdp:
loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x,
@ -845,7 +846,7 @@ class Vits(BaseTTS):
# interpolate z if needed
if self.args.interpolate_z:
z = torch.nn.functional.interpolate(
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear"
).squeeze(0)
# recompute the mask if needed
if y_lengths is not None and y_mask is not None:
@ -890,7 +891,7 @@ class Vits(BaseTTS):
Return Shapes:
- model_outputs: :math:`[B, 1, T_wav]`
- alignments: :math:`[B, T_seq, T_dec]`
- alignments: :math:`[B, T_dec, T_seq]`
- z: :math:`[B, C, T_dec]`
- z_p: :math:`[B, C, T_dec]`
- m_p: :math:`[B, C, T_dec]`
@ -924,8 +925,8 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
# expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p]) # [B, 1, T_de, T_en] -> [B, C, T_de]
logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de]
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@ -982,6 +983,7 @@ class Vits(BaseTTS):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
@torch.no_grad()
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
): # pylint: disable=dangerous-default-value
@ -1053,6 +1055,7 @@ class Vits(BaseTTS):
outputs = {
"model_outputs": o,
"alignments": attn.squeeze(1),
"durations": w_ceil,
"z": z,
"z_p": z_p,
"m_p": m_p,
@ -1231,7 +1234,7 @@ class Vits(BaseTTS):
audios = {f"{name_prefix}/audio": sample_voice}
alignments = outputs[1]["alignments"]
align_img = alignments[0].data.cpu().numpy().T
align_img = alignments[0].data.cpu().numpy()
figures.update(
{
@ -1390,16 +1393,16 @@ class Vits(BaseTTS):
"""Compute spectrograms on the device."""
ac = self.config.audio
if self.args.encoder_sample_rate:
wav = self.audio_resampler(batch["waveform"])
else:
wav = batch["waveform"]
# compute spectrograms
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
if self.args.encoder_sample_rate:
# downsample audio to encoder sample rate
wav = self.audio_resampler(batch["waveform"])
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
else:
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
if self.args.encoder_sample_rate:
# recompute spec with high sampling rate to the loss
# recompute spec with vocoder sampling rate
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
# remove extra stft frame
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]

View File

@ -180,10 +180,10 @@ def maximum_path(value, mask):
def maximum_path_cython(value, mask):
"""Cython optimised version.
Shapes:
- value: :math:`[B, T_en, T_de]`
- mask: :math:`[B, T_en, T_de]`
- value: :math:`[B, T_de, T_en]`
- mask: :math:`[B, T_de, T_en]`
"""
value = value * mask
# value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)

View File

@ -1,14 +1,10 @@
import numpy as np
cimport cython
cimport numpy as np
import cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_
if x == y:
v_cur = max_neg_val
else:
v_cur = value[x, y-1]
v_cur = value[y-1, x]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[x-1, y-1]
value[x, y] = max(v_cur, v_prev) + value[x, y]
v_prev = value[y-1, x-1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[index, y] = 1
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
path[y, index] = 1
if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
cdef int b = values.shape[0]
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])