diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py index f7ff219a..68a3eec2 100644 --- a/TTS/tts/configs/fastspeech2_config.py +++ b/TTS/tts/configs/fastspeech2_config.py @@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig): base_model: str = "forward_tts" # model specific params - model_args: ForwardTTSArgs = ForwardTTSArgs() + model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True) # multi-speaker settings num_speakers: int = 0 diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index db74186b..df01d663 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -189,6 +189,8 @@ class TTSDataset(Dataset): self._samples = new_samples if hasattr(self, "f0_dataset"): self.f0_dataset.samples = new_samples + if hasattr(self, "energy_dataset"): + self.energy_dataset.samples = new_samples if hasattr(self, "phoneme_dataset"): self.phoneme_dataset.samples = new_samples @@ -856,11 +858,11 @@ class EnergyDataset: def __getitem__(self, idx): item = self.samples[idx] - energy = self.compute_or_load(item["audio_file"]) + energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) if self.normalize_energy: assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" energy = self.normalize(energy) - return {"audio_file": item["audio_file"], "energy": energy} + return {"audio_unique_name": item["audio_unique_name"], "energy": energy} def __len__(self): return len(self.samples) @@ -884,7 +886,7 @@ class EnergyDataset: if self.normalize_energy: computed_data = [tensor for batch in computed_data for tensor in batch] # flatten - energy_mean, energy_std = self.compute_pitch_stats(computed_data) + energy_mean, energy_std = self.compute_energy_stats(computed_data) energy_stats = {"mean": energy_mean, "std": energy_std} np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) @@ -900,7 +902,7 @@ class EnergyDataset: @staticmethod def _compute_and_save_energy(ap, wav_file, energy_file=None): wav = ap.load_wav(wav_file) - energy = calculate_energy(wav) + energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length) if energy_file: np.save(energy_file, energy) return energy @@ -931,11 +933,11 @@ class EnergyDataset: energy[zero_idxs] = 0.0 return energy - def compute_or_load(self, wav_file): + def compute_or_load(self, wav_file, audio_unique_name): """ compute energy and return a numpy array of energy values """ - energy_file = self.create_Energy_file_path(wav_file, self.cache_path) + energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) if not os.path.exists(energy_file): energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) else: @@ -943,14 +945,14 @@ class EnergyDataset: return energy.astype(np.float32) def collate_fn(self, batch): - audio_file = [item["audio_file"] for item in batch] + audio_unique_name = [item["audio_unique_name"] for item in batch] energys = [item["energy"] for item in batch] energy_lens = [len(item["energy"]) for item in batch] energy_lens_max = max(energy_lens) energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) for i, energy_len in enumerate(energy_lens): energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) - return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens} + return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens} def print_logs(self, level: int = 0) -> None: indent = "\t" * level diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index dc53edd0..37a09354 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -183,6 +183,7 @@ class BaseTTS(BaseTrainerModel): attn_mask = batch["attns"] waveform = batch["waveform"] pitch = batch["pitch"] + energy = batch["energy"] language_ids = batch["language_ids"] max_text_length = torch.max(text_lengths.float()) max_spec_length = torch.max(mel_lengths.float()) @@ -231,6 +232,7 @@ class BaseTTS(BaseTrainerModel): "item_idx": item_idx, "waveform": waveform, "pitch": pitch, + "energy": energy, "language_ids": language_ids, "audio_unique_names": batch["audio_unique_names"], } @@ -313,6 +315,8 @@ class BaseTTS(BaseTrainerModel): compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), samples=samples, ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, diff --git a/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py index d12f8bed..35bda597 100644 --- a/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py +++ b/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py @@ -38,7 +38,7 @@ config = Fastspeech2Config( f0_cache_path="tests/data/ljspeech/f0_cache/", compute_f0=True, compute_energy=True, - energy_cache_path="tests/data/ljspeech/f0_cache/", + energy_cache_path="tests/data/ljspeech/energy_cache/", run_eval=True, test_delay_epochs=-1, epochs=1, diff --git a/tests/tts_tests/test_fastspeech_2_train.py b/tests/tts_tests/test_fastspeech_2_train.py index f54e6351..dd4b07d2 100644 --- a/tests/tts_tests/test_fastspeech_2_train.py +++ b/tests/tts_tests/test_fastspeech_2_train.py @@ -38,7 +38,7 @@ config = Fastspeech2Config( f0_cache_path="tests/data/ljspeech/f0_cache/", compute_f0=True, compute_energy=True, - energy_cache_path="tests/data/ljspeech/f0_cache/", + energy_cache_path="tests/data/ljspeech/energy_cache/", run_eval=True, test_delay_epochs=-1, epochs=1,