Add upsampling_init and upsampling_z methods

This commit is contained in:
Edresson Casanova 2022-04-22 09:03:31 -03:00
parent b3e2c58398
commit ce7138d9d4
1 changed files with 34 additions and 22 deletions

View File

@ -566,6 +566,7 @@ class Vits(BaseTTS):
self.init_multispeaker(config)
self.init_multilingual(config)
self.init_upsampling()
self.length_scale = self.args.length_scale
self.noise_scale = self.args.noise_scale
@ -648,12 +649,6 @@ class Vits(BaseTTS):
use_spectral_norm=self.args.use_spectral_norm_disriminator,
)
if self.args.encoder_sample_rate:
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
)
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
@ -734,6 +729,16 @@ class Vits(BaseTTS):
else:
self.embedded_language_dim = 0
def init_upsampling(self):
"""
Initialize upsampling modules of a model.
"""
if self.args.encoder_sample_rate:
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
)
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
@ -831,6 +836,25 @@ class Vits(BaseTTS):
outputs["loss_duration"] = loss_duration
return outputs, attn
def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
spec_segment_size = self.spec_segment_size
if self.args.encoder_sample_rate:
# recompute the slices and spec_segment_size if needed
slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
spec_segment_size = spec_segment_size * int(self.interpolate_factor)
# interpolate z if needed
if self.args.interpolate_z:
z = torch.nn.functional.interpolate(
z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
).squeeze(0)
# recompute the mask if needed
if y_lengths is not None and y_mask is not None:
y_mask = (
sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
) # [B, 1, T_dec_resampled]
return z, spec_segment_size, slice_ids, y_mask
def forward( # pylint: disable=dangerous-default-value
self,
x: torch.tensor,
@ -906,16 +930,8 @@ class Vits(BaseTTS):
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
if self.args.encoder_sample_rate:
slice_ids = slice_ids * int(self.interpolate_factor)
spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
if self.args.interpolate_z:
z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable
z_slice = torch.nn.functional.interpolate(
z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest"
).squeeze(0)
else:
spec_segment_size = self.spec_segment_size
# interpolate z if needed
z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
o = self.waveform_decoder(z_slice, g=g)
@ -1029,12 +1045,8 @@ class Vits(BaseTTS):
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
if self.args.encoder_sample_rate and self.args.interpolate_z:
z = z.unsqueeze(0) # pylint: disable=not-callable
z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
y_mask = (
sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
) # [B, 1, T_dec_resampled]
# upsampling if needed
z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)