diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py index ef7bb155..5bda50de 100644 --- a/TTS/tts/models/forward_tts_e2e.py +++ b/TTS/tts/models/forward_tts_e2e.py @@ -89,7 +89,10 @@ class ForwardTTSE2eF0Dataset(F0Dataset): def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None): wav, _ = load_audio(wav_file) f0 = compute_f0( - x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax + x=wav.numpy()[0], + sample_rate=audio_config.sample_rate, + hop_length=audio_config.hop_length, + pitch_fmax=audio_config.pitch_fmax, ) # skip the last F0 value to align with the spectrogram if wav.shape[1] % audio_config.hop_length != 0: @@ -104,7 +107,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset): """ pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) if not os.path.exists(pitch_file): - pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file) + pitch = self._compute_and_save_pitch( + audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file + ) else: pitch = np.load(pitch_file) return pitch.astype(np.float32) @@ -300,6 +305,9 @@ class ForwardTTSE2eArgs(ForwardTTSArgs): upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) upsample_initial_channel_decoder: int = 512 upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + # discriminator + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) # multi-speaker params use_speaker_embedding: bool = False num_speakers: int = 0 @@ -359,7 +367,18 @@ class ForwardTTSE2e(BaseTTSE2E): # use Vits Discriminator for limiting VRAM use if self.args.init_discriminator: - self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_discriminator) + self.disc = VitsDiscriminator( + use_spectral_norm=self.args.use_spectral_norm_discriminator, + periods=self.args.periods_discriminator, + upsampling_rates=self.args.upsampling_rates_discriminator, + ) + + # def check_model_args(self): + # upsample_rate = torch.prod(torch.as_tensor(self.args.upsample_rates_decoder)).item() + # if s + # assert ( + # upsample_rate == self.config.audio.hop_length + # ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" def init_multispeaker(self, config: Coqpit): """Init for multi-speaker training. @@ -440,9 +459,10 @@ class ForwardTTSE2e(BaseTTSE2E): let_short_samples=True, pad_short=True, ) + vocoder_output = self.waveform_decoder( x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices, - g=encoder_outputs["g"], + g=encoder_outputs["spk_emb"], ) wav_seg = segment( waveform, @@ -461,7 +481,7 @@ class ForwardTTSE2e(BaseTTSE2E): def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True) o_en_ex = encoder_outputs["o_en_ex"] - vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"]) + vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["spk_emb"]) model_outputs = {**encoder_outputs} model_outputs["model_outputs"] = vocoder_output return model_outputs @@ -860,9 +880,11 @@ class ForwardTTSE2e(BaseTTSE2E): center=False, ) - assert ( - batch["pitch"].shape[2] == batch["mel_input"].shape[2] - ), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}" + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() # zero the padding frames @@ -990,9 +1012,7 @@ class ForwardTTSE2e(BaseTTSE2E): # language_manager = LanguageManager.init_from_config(config) return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager) - def load_checkpoint( - self, config, checkpoint_path, eval=False - ): + def load_checkpoint(self, config, checkpoint_path, eval=False): """Load model from a checkpoint created by the 👟""" # pylint: disable=unused-argument, redefined-builtin state = torch.load(checkpoint_path, map_location=torch.device("cpu")) @@ -1003,11 +1023,7 @@ class ForwardTTSE2e(BaseTTSE2E): def get_state_dict(self): """Custom state dict of the model with all the necessary components for inference.""" - save_state = { - "config": self.config.to_dict(), - "args": self.args.to_dict(), - "model": self.state_dict - } + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} if hasattr(self, "emb_g"): save_state["speaker_ids"] = self.speaker_manager.speaker_ids