This commit is contained in:
Eren Gölge 2022-05-03 00:34:55 +02:00
parent 76b274e690
commit b78431d7e2
4 changed files with 37 additions and 39 deletions

View File

@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module):
@staticmethod @staticmethod
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
""" """
z_p, logs_q: [b, h, t_t] z_p, logs_q: [B, C, T_de]
m_p, logs_p: [b, h, t_t] m_p, logs_p: [B, C, T_de]
""" """
z_p = z_p.float() z_p = z_p.float()
logs_q = logs_q.float() logs_q = logs_q.float()

View File

@ -735,6 +735,7 @@ class Vits(BaseTTS):
""" """
if self.args.encoder_sample_rate: if self.args.encoder_sample_rate:
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer."
self.audio_resampler = torchaudio.transforms.Resample( self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
) # pylint: disable=W0201 ) # pylint: disable=W0201
@ -803,18 +804,18 @@ class Vits(BaseTTS):
def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
# find the alignment path # find the alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
with torch.no_grad(): with torch.no_grad():
o_scale = torch.exp(-2 * logs_p) o_scale = torch.exp(-2 * logs_p)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [B, 1, T_en]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale]) # [B, T_de, T_en]
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale]) # [B, T_de, T_en]
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True) # [B, 1, T_en]
logp = logp2 + logp3 + logp1 + logp4 logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en]
# duration predictor # duration predictor
attn_durations = attn.sum(3) attn_durations = attn.sum(2)
if self.args.use_sdp: if self.args.use_sdp:
loss_duration = self.duration_predictor( loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x, x.detach() if self.args.detach_dp_input else x,
@ -845,7 +846,7 @@ class Vits(BaseTTS):
# interpolate z if needed # interpolate z if needed
if self.args.interpolate_z: if self.args.interpolate_z:
z = torch.nn.functional.interpolate( z = torch.nn.functional.interpolate(
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest" z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear"
).squeeze(0) ).squeeze(0)
# recompute the mask if needed # recompute the mask if needed
if y_lengths is not None and y_mask is not None: if y_lengths is not None and y_mask is not None:
@ -890,7 +891,7 @@ class Vits(BaseTTS):
Return Shapes: Return Shapes:
- model_outputs: :math:`[B, 1, T_wav]` - model_outputs: :math:`[B, 1, T_wav]`
- alignments: :math:`[B, T_seq, T_dec]` - alignments: :math:`[B, T_dec, T_seq]`
- z: :math:`[B, C, T_dec]` - z: :math:`[B, C, T_dec]`
- z_p: :math:`[B, C, T_dec]` - z_p: :math:`[B, C, T_dec]`
- m_p: :math:`[B, C, T_dec]` - m_p: :math:`[B, C, T_dec]`
@ -924,8 +925,8 @@ class Vits(BaseTTS):
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
# expand prior # expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p]) # [B, 1, T_de, T_en] -> [B, C, T_de]
logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de]
# select a random feature segment for the waveform decoder # select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@ -982,6 +983,7 @@ class Vits(BaseTTS):
return aux_input["x_lengths"] return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device) return torch.tensor(x.shape[1:2]).to(x.device)
@torch.no_grad()
def inference( def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
@ -1053,6 +1055,7 @@ class Vits(BaseTTS):
outputs = { outputs = {
"model_outputs": o, "model_outputs": o,
"alignments": attn.squeeze(1), "alignments": attn.squeeze(1),
"durations": w_ceil,
"z": z, "z": z,
"z_p": z_p, "z_p": z_p,
"m_p": m_p, "m_p": m_p,
@ -1231,7 +1234,7 @@ class Vits(BaseTTS):
audios = {f"{name_prefix}/audio": sample_voice} audios = {f"{name_prefix}/audio": sample_voice}
alignments = outputs[1]["alignments"] alignments = outputs[1]["alignments"]
align_img = alignments[0].data.cpu().numpy().T align_img = alignments[0].data.cpu().numpy()
figures.update( figures.update(
{ {
@ -1390,16 +1393,16 @@ class Vits(BaseTTS):
"""Compute spectrograms on the device.""" """Compute spectrograms on the device."""
ac = self.config.audio ac = self.config.audio
if self.args.encoder_sample_rate:
wav = self.audio_resampler(batch["waveform"])
else:
wav = batch["waveform"]
# compute spectrograms # compute spectrograms
if self.args.encoder_sample_rate:
# downsample audio to encoder sample rate
wav = self.audio_resampler(batch["waveform"])
batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
else:
batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
if self.args.encoder_sample_rate: if self.args.encoder_sample_rate:
# recompute spec with high sampling rate to the loss # recompute spec with vocoder sampling rate
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
# remove extra stft frame # remove extra stft frame
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]

View File

@ -180,10 +180,10 @@ def maximum_path(value, mask):
def maximum_path_cython(value, mask): def maximum_path_cython(value, mask):
"""Cython optimised version. """Cython optimised version.
Shapes: Shapes:
- value: :math:`[B, T_en, T_de]` - value: :math:`[B, T_de, T_en]`
- mask: :math:`[B, T_en, T_de]` - mask: :math:`[B, T_de, T_en]`
""" """
value = value * mask # value = value * mask
device = value.device device = value.device
dtype = value.dtype dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32) value = value.data.cpu().numpy().astype(np.float32)

View File

@ -1,14 +1,10 @@
import numpy as np import cython
cimport cython
cimport numpy as np
from cython.parallel import prange from cython.parallel import prange
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x cdef int x
cdef int y cdef int y
cdef float v_prev cdef float v_prev
@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_
if x == y: if x == y:
v_cur = max_neg_val v_cur = max_neg_val
else: else:
v_cur = value[x, y-1] v_cur = value[y-1, x]
if x == 0: if x == 0:
if y == 0: if y == 0:
v_prev = 0. v_prev = 0.
else: else:
v_prev = max_neg_val v_prev = max_neg_val
else: else:
v_prev = value[x-1, y-1] v_prev = value[y-1, x-1]
value[x, y] = max(v_cur, v_prev) + value[x, y] value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1): for y in range(t_y - 1, -1, -1):
path[index, y] = 1 path[y, index] = 1
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
index = index - 1 index = index - 1
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = values.shape[0] cdef int b = paths.shape[0]
cdef int i cdef int i
for i in prange(b, nogil=True): for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])