diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 364a007b..33aecbf3 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -566,6 +566,7 @@ class Vits(BaseTTS):
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
+        self.init_upsampling()
 
         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -648,12 +649,6 @@ class Vits(BaseTTS):
                 use_spectral_norm=self.args.use_spectral_norm_disriminator,
             )
 
-        if self.args.encoder_sample_rate:
-            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
-            self.audio_resampler = torchaudio.transforms.Resample(
-                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
-            )
-
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
@@ -734,6 +729,16 @@ class Vits(BaseTTS):
         else:
             self.embedded_language_dim = 0
 
+    def init_upsampling(self):
+        """
+        Initialize upsampling modules of a model.
+        """
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
+            )
+
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
@@ -831,6 +836,25 @@ class Vits(BaseTTS):
             outputs["loss_duration"] = loss_duration
         return outputs, attn
 
+    def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
+        spec_segment_size = self.spec_segment_size
+        if self.args.encoder_sample_rate:
+            # recompute the slices and spec_segment_size if needed
+            slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
+            spec_segment_size = spec_segment_size * int(self.interpolate_factor)
+            # interpolate z if needed
+            if self.args.interpolate_z:
+                z = torch.nn.functional.interpolate(
+                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
+                ).squeeze(0)
+                # recompute the mask if needed
+                if y_lengths is not None and y_mask is not None:
+                    y_mask = (
+                        sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
+                    )  # [B, 1, T_dec_resampled]
+
+        return z, spec_segment_size, slice_ids, y_mask
+
     def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,
@@ -906,16 +930,8 @@ class Vits(BaseTTS):
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
-        if self.args.encoder_sample_rate:
-            slice_ids = slice_ids * int(self.interpolate_factor)
-            spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
-            if self.args.interpolate_z:
-                z_slice = z_slice.unsqueeze(0)  # pylint: disable=not-callable
-                z_slice = torch.nn.functional.interpolate(
-                    z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest"
-                ).squeeze(0)
-        else:
-            spec_segment_size = self.spec_segment_size
+        # interpolate z if needed
+        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
 
         o = self.waveform_decoder(z_slice, g=g)
 
@@ -1029,12 +1045,8 @@ class Vits(BaseTTS):
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
 
-        if self.args.encoder_sample_rate and self.args.interpolate_z:
-            z = z.unsqueeze(0)  # pylint: disable=not-callable
-            z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
-            y_mask = (
-                sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
-            )  # [B, 1, T_dec_resampled]
+        # upsampling if needed
+        z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
 
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
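
Below is a minimal standalone sketch of the tensor mechanics the new `upsampling_z` helper centralizes: rescaling the frame-domain slice indices and segment size by `interpolate_factor`, nearest-neighbor interpolation of the latent `z` along time only, and recomputing the decoder mask for the resampled length. The dummy shapes, the fixed 2x factor, and the local `sequence_mask` stand-in (mirroring the helper the diff imports) are illustrative assumptions, not code from this diff.

import torch

def sequence_mask(lengths, max_len=None):
    # Minimal stand-in (assumption) for TTS.tts.utils.helpers.sequence_mask:
    # builds a [B, T] boolean mask from per-item lengths.
    if max_len is None:
        max_len = int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

# Dummy shapes: batch of 2, 192 latent channels, 100 decoder frames,
# and a 2x factor (e.g. output 44100 Hz over encoder 22050 Hz).
interpolate_factor = 2.0
z = torch.randn(2, 192, 100)
y_lengths = torch.tensor([100, 80])
slice_ids = torch.tensor([10, 20])
spec_segment_size = 32

# Rescale frame-domain indices to the upsampled time base.
slice_ids = slice_ids * int(interpolate_factor)
spec_segment_size = spec_segment_size * int(interpolate_factor)

# Nearest-neighbor interpolation along time only: unsqueeze to 4D so the
# [1, factor] scale leaves the channel axis untouched.
z_up = torch.nn.functional.interpolate(
    z.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest"
).squeeze(0)

# Mask recomputed for the resampled lengths: [B, 1, T_dec_resampled].
y_mask = sequence_mask(y_lengths * interpolate_factor).unsqueeze(1)

print(z_up.shape)    # torch.Size([2, 192, 200])
print(y_mask.shape)  # torch.Size([2, 1, 200])

Folding both call sites onto this one helper keeps the training path (which rescales `slice_ids`) and the inference path (which rescales `y_mask`) from drifting apart, since each now ignores the return values it does not need.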