diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index fbf71e1b..313943dd 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -151,7 +151,7 @@ class VitsConfig(BaseTTSConfig):
     d_vector_dim: int = None

     # dataset configs
-    compute_f0: bool = False
+    compute_pitch: bool = False
     f0_cache_path: str = None

     def __post_init__(self):
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index b7d87237..b32ca1e3 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -646,6 +646,7 @@ class VitsGeneratorLoss(nn.Module):
         if loss_spk_reversal_classifier is not None:
             loss += loss_spk_reversal_classifier
             return_dict["loss_spk_reversal_classifier"] = loss_spk_reversal_classifier
+
         if pitch_loss is not None:
             pitch_loss = pitch_loss * self.pitch_loss_alpha
             loss += pitch_loss
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 0eb91719..cc97f81e 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -189,7 +189,7 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
     spec = amp_to_db(spec)
     return spec

-def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
+def compute_pitch(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:
     """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

     Args:
@@ -217,8 +217,8 @@ def compute_f0(x: np.ndarray, sample_rate, hop_length, pitch_fmax=800.0) -> np.ndarray:

 class VITSF0Dataset(F0Dataset):
     def __init__(self, config, *args, **kwargs):
+        self.audio_config = config.audio
         super().__init__(*args, **kwargs)
-        self.config = config

     def compute_or_load(self, wav_file):
         """
@@ -226,15 +226,15 @@ class VITSF0Dataset(F0Dataset):
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(wav_file, self.audio_config.sample_rate, self.audio_config.hop_length, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)

-    def _compute_and_save_pitch(self, wav_file, pitch_file=None):
-        print(wav_file, pitch_file)
+    @staticmethod
+    def _compute_and_save_pitch(wav_file, sample_rate, hop_length, pitch_file=None):
         wav, _ = load_audio(wav_file)
-        pitch = compute_f0(wav.squeeze().numpy(), self.config.audio.sample_rate, self.config.audio.hop_length)
+        pitch = compute_pitch(wav.squeeze().numpy(), sample_rate, hop_length)
         if pitch_file:
             np.save(pitch_file, pitch)
         return pitch
@@ -242,11 +242,14 @@

 class VitsDataset(TTSDataset):
-    def __init__(self, config, *args, **kwargs):
+    def __init__(self, config, compute_pitch=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
+        self.compute_pitch = compute_pitch
+
-        self.f0_dataset = VITSF0Dataset(config,
+        if self.compute_pitch:
+            self.f0_dataset = VITSF0Dataset(config,
                 samples=self.samples,
                 ap=self.ap,
                 cache_path=self.f0_cache_path,
                 precompute_num_workers=self.precompute_num_workers
             )
@@ -261,7 +264,7 @@

         # get f0 values
         f0 = None
-        if self.compute_f0:
+        if self.compute_pitch:
             f0 = self.get_f0(idx)["f0"]

         # after phonemization the text length may change
@@ -335,7 +338,7 @@

         # format F0
-        if self.compute_f0:
+        if self.compute_pitch:
             pitch = prepare_data(batch["pitch"])
             pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
         else:
@@ -592,6 +595,7 @@
     prosody_embedding_dim: int = 0
     prosody_encoder_num_heads: int = 1
     prosody_encoder_num_tokens: int = 5
+    use_prosody_enc_spk_reversal_classifier: bool = True

     # Pitch predictor
     use_pitch: bool = False
@@ -739,12 +743,6 @@
                 self.args.pitch_predictor_dropout_p,
                 cond_channels=dp_cond_embedding_dim,
             )
-            self.pitch_emb = nn.Conv1d(
-                1,
-                self.args.hidden_channels,
-                kernel_size=self.args.pitch_embedding_kernel_size,
-                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
-            )

         if self.args.use_prosody_encoder:
             self.prosody_encoder = GST(
@@ -753,11 +751,12 @@
                 num_style_tokens=self.args.prosody_encoder_num_tokens,
                 gst_embedding_dim=self.args.prosody_embedding_dim,
             )
-            self.speaker_reversal_classifier = ReversalClassifier(
-                in_channels=self.args.prosody_embedding_dim,
-                out_channels=self.num_speakers,
-                hidden_channels=256,
-            )
+            if self.args.use_prosody_enc_spk_reversal_classifier:
+                self.speaker_reversal_classifier = ReversalClassifier(
+                    in_channels=self.args.prosody_embedding_dim,
+                    out_channels=self.num_speakers,
+                    hidden_channels=256,
+                )

         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
@@ -1020,10 +1019,9 @@
             x_mask,
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )
-        print(o_pitch.shape, pitch.shape, dr.shape)
+
         avg_pitch = average_over_durations(pitch, dr.squeeze())
-        o_pitch_emb = self.pitch_emb(avg_pitch)
-        pitch_loss = torch.sum(torch.sum((o_pitch_emb - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
+        pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
         return pitch_loss

     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
@@ -1137,7 +1135,8 @@
         l_pros_speaker = None
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
-            _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
+            if self.args.use_prosody_enc_spk_reversal_classifier:
+                _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)

         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
@@ -1160,6 +1159,7 @@

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        pitch_loss = None
         if self.args.use_pitch:
             pitch_loss = self.forward_pitch_predictor(x, x_mask, pitch, attn.sum(3), g_dp)
@@ -1781,7 +1781,7 @@
             verbose=verbose,
             tokenizer=self.tokenizer,
             start_by_longest=config.start_by_longest,
-            compute_f0=config.get("compute_f0", False),
+            compute_pitch=config.get("compute_pitch", False),
             f0_cache_path=config.get("f0_cache_path", None),
         )
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
index 9b13d501..30d9f0f6 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder_with_pitch_predictor.py
@@ -25,7 +25,7 @@ config = VitsConfig(
    epochs=1,
    print_step=1,
    print_eval=True,
-    compute_f0=True,
+    compute_pitch=True,
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    test_sentences=[
        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
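Reviewer note on the pitch-loss change: forward_pitch_predictor now drops the pitch_emb projection and penalizes the predictor output o_pitch directly against the duration-averaged ground-truth f0. The sketch below is a minimal, self-contained illustration of that computation, not the patch itself: average_over_durations_ref is a toy stand-in for the average_over_durations helper used in the diff, and the shapes (pitch: B x 1 x T_frames; o_pitch and x_mask token-level) are assumptions inferred from the surrounding code.

import torch

def average_over_durations_ref(pitch: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    # Toy stand-in (illustration only) for the average_over_durations helper:
    # collapse frame-level f0 [B, 1, T_frames] to token-level f0 [B, 1, T_tokens]
    # by averaging each token's span of frames given per-token durations.
    B, T_tok = durs.shape
    out = pitch.new_zeros(B, 1, T_tok)
    for b in range(B):
        start = 0
        for t in range(T_tok):
            d = int(durs[b, t])
            if d > 0:
                out[b, 0, t] = pitch[b, 0, start:start + d].mean()
            start += d
    return out

# Hypothetical shapes, chosen only to make the example run.
B, T_tok, frames_per_token = 2, 6, 3
T_frames = T_tok * frames_per_token
x_mask = torch.ones(B, 1, T_tok)                      # token-level mask
durs = torch.full((B, T_tok), frames_per_token)       # per-token durations (sum to T_frames)
pitch = 100.0 + 200.0 * torch.rand(B, 1, T_frames)    # frame-level f0 target (Hz)
o_pitch = 100.0 + 200.0 * torch.rand(B, 1, T_tok)     # predictor output, token level

# Same expression as the new forward_pitch_predictor: squared error summed over
# channel and time, normalized by the number of valid mask entries.
avg_pitch = average_over_durations_ref(pitch, durs)
pitch_loss = torch.sum(torch.sum((avg_pitch - o_pitch) ** 2, [1, 2]) / torch.sum(x_mask))
print(pitch_loss.item())

This also shows why pitch_emb could be deleted: once the loss is computed in f0 space rather than embedding space, no learned projection of avg_pitch is needed.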