From b3e2c58398900912dad406b41b1fe36295905ed8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 22 Apr 2022 07:57:27 -0300 Subject: [PATCH] Rename TTS_part_sample_rate to encoder_sample_rate --- TTS/tts/models/vits.py | 32 +++++++++---------- ...train_upsampling_interpolation_approach.py | 2 +- ...r_emb_train_upsampling_vocoder_approach.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 9e566e87..364a007b 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -455,14 +455,14 @@ class VitsArgs(Coqpit): freeze_waveform_decoder (bool): Freeze the waveform decoder weigths during training. Defaults to False. - TTS_part_sample_rate (int): + encoder_sample_rate (int): If not None this sample rate will be used for training the Posterior Encoder, flow, text_encoder and duration predictor. The decoder part (vocoder) will be trained with the `config.audio.sample_rate`. Defaults to None. interpolate_z (bool): - If `TTS_part_sample_rate` not None and this parameter True the nearest interpolation - will be used to upsampling the latent variable z with the sampling rate `TTS_part_sample_rate` + If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation + will be used to upsample the latent variable z from the sampling rate `encoder_sample_rate` to the `config.audio.sample_rate`. If it is False you will need to add extra `upsample_rates_decoder` to match the shape. Defaults to True. 
@@ -521,7 +521,7 @@ class VitsArgs(Coqpit): freeze_PE: bool = False freeze_flow_decoder: bool = False freeze_waveform_decoder: bool = False - TTS_part_sample_rate: int = None + encoder_sample_rate: int = None interpolate_z: bool = True @@ -648,10 +648,10 @@ class Vits(BaseTTS): use_spectral_norm=self.args.use_spectral_norm_disriminator, ) - if self.args.TTS_part_sample_rate: - self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate + if self.args.encoder_sample_rate: + self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate self.audio_resampler = torchaudio.transforms.Resample( - orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate ) def init_multispeaker(self, config: Coqpit): @@ -906,7 +906,7 @@ class Vits(BaseTTS): # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) - if self.args.TTS_part_sample_rate: + if self.args.encoder_sample_rate: slice_ids = slice_ids * int(self.interpolate_factor) spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) if self.args.interpolate_z: @@ -1029,7 +1029,7 @@ class Vits(BaseTTS): z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale z = self.flow(z_p, y_mask, g=g, reverse=True) - if self.args.TTS_part_sample_rate and self.args.interpolate_z: + if self.args.encoder_sample_rate and self.args.interpolate_z: z = z.unsqueeze(0) # pylint: disable=not-callable z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0) y_mask = ( @@ -1155,7 +1155,7 @@ class Vits(BaseTTS): # compute melspec segment with autocast(enabled=False): - if self.args.TTS_part_sample_rate: + if self.args.encoder_sample_rate: spec_segment_size = 
self.spec_segment_size * int(self.interpolate_factor) else: spec_segment_size = self.spec_segment_size @@ -1370,7 +1370,7 @@ class Vits(BaseTTS): """Compute spectrograms on the device.""" ac = self.config.audio - if self.args.TTS_part_sample_rate: + if self.args.encoder_sample_rate: wav = self.audio_resampler(batch["waveform"]) else: wav = batch["waveform"] @@ -1378,7 +1378,7 @@ class Vits(BaseTTS): # compute spectrograms batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) - if self.args.TTS_part_sample_rate: + if self.args.encoder_sample_rate: # recompute spec with high sampling rate to the loss spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) else: @@ -1393,14 +1393,14 @@ class Vits(BaseTTS): fmax=ac.mel_fmax, ) - if not self.args.TTS_part_sample_rate: + if not self.args.encoder_sample_rate: assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" # compute spectrogram frame lengths batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() - if not self.args.TTS_part_sample_rate: + if not self.args.encoder_sample_rate: assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 # zero the padding frames @@ -1518,7 +1518,7 @@ class Vits(BaseTTS): # as it is probably easier for model distribution. 
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} - if self.args.TTS_part_sample_rate is not None and eval: + if self.args.encoder_sample_rate is not None and eval: # audio resampler is not used in inference time self.audio_resampler = None @@ -1549,7 +1549,7 @@ class Vits(BaseTTS): from TTS.utils.audio import AudioProcessor upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() - if not config.model_args.TTS_part_sample_rate: + if not config.model_args.encoder_sample_rate: assert ( upsample_rate == config.audio.hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py index 9d9e372c..c279d004 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py +++ b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_interpolation_approach.py @@ -42,7 +42,7 @@ config.model_args.d_vector_dim = 256 # test upsample interpolation approach -config.model_args.TTS_part_sample_rate = 11025 +config.model_args.encoder_sample_rate = 11025 config.model_args.interpolate_z = True config.model_args.upsample_rates_decoder = [8, 8, 2, 2] config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7] diff --git a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py index 758aa4a1..35248b4c 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py +++ b/tests/tts_tests/test_vits_speaker_emb_train_upsampling_vocoder_approach.py @@ -42,7 +42,7 @@ config.model_args.d_vector_dim = 256 # test upsample -config.model_args.TTS_part_sample_rate = 11025 +config.model_args.encoder_sample_rate = 11025 
config.model_args.interpolate_z = False config.model_args.upsample_rates_decoder = [8, 8, 4, 2] config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7]