diff --git a/TTS/tts/configs/fast_pitch_e2e_config.py b/TTS/tts/configs/fast_pitch_e2e_config.py
index f86cf459..21bfc4c3 100644
--- a/TTS/tts/configs/fast_pitch_e2e_config.py
+++ b/TTS/tts/configs/fast_pitch_e2e_config.py
@@ -102,7 +102,7 @@ class FastPitchE2eConfig(BaseTTSConfig):
             Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
     """
 
-    model: str = "fast_pitch_e2e_hifigan"
+    model: str = "fast_pitch_e2e"
     base_model: str = "forward_tts_e2e"
 
     # model specific params
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index ed655a32..072da27e 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -9,7 +9,7 @@ import tqdm
 from torch.utils.data import Dataset
 
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
+from TTS.utils.audio.numpy_transforms import compute_f0, load_wav, wav_to_mel, wav_to_spec
 
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -647,14 +647,14 @@ class F0Dataset:
     def __init__(
         self,
         samples: Union[List[List], List[Dict]],
-        ap: "AudioProcessor",
+        audio_config: "AudioConfig",
         verbose=False,
         cache_path: str = None,
         precompute_num_workers=0,
         normalize_f0=True,
     ):
         self.samples = samples
-        self.audio_config = ap
+        self.audio_config = audio_config
         self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
@@ -711,9 +711,9 @@ class F0Dataset:
         return pitch_file
 
     @staticmethod
-    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
-        wav = ap.load_wav(wav_file)
-        pitch = ap.compute_f0(wav)
+    def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
+        wav = load_wav(filename=wav_file, sample_rate=audio_config.sample_rate)
+        pitch = compute_f0(x=wav, pitch_fmax=audio_config.pitch_fmax, hop_length=audio_config.hop_length, sample_rate=audio_config.sample_rate)
         if pitch_file:
             np.save(pitch_file, pitch)
         return pitch
diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py
index 18b6d7c9..ef7bb155 100644
--- a/TTS/tts/models/forward_tts_e2e.py
+++ b/TTS/tts/models/forward_tts_e2e.py
@@ -76,10 +76,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         precompute_num_workers=0,
         normalize_f0=True,
     ):
-        self.audio_config = audio_config
         super().__init__(
             samples=samples,
-            ap=None,
+            audio_config=audio_config,
             verbose=verbose,
             cache_path=cache_path,
             precompute_num_workers=precompute_num_workers,
@@ -87,13 +86,13 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         )
 
     @staticmethod
-    def _compute_and_save_pitch(config, wav_file, pitch_file=None):
+    def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
         wav, _ = load_audio(wav_file)
         f0 = compute_f0(
-            x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
+            x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax
         )
         # skip the last F0 value to align with the spectrogram
-        if wav.shape[1] % config.hop_length != 0:
+        if wav.shape[1] % audio_config.hop_length != 0:
             f0 = f0[:-1]
         if pitch_file:
             np.save(pitch_file, f0)
@@ -105,7 +104,7 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
@@ -117,13 +116,12 @@ class ForwardTTSE2eDataset(TTSDataset):
         compute_f0 = kwargs.pop("compute_f0", False)
         kwargs["compute_f0"] = False
 
-        self.audio_config = kwargs["audio_config"]
-        del kwargs["audio_config"]
-
         super().__init__(*args, **kwargs)
 
         self.compute_f0 = compute_f0
         self.pad_id = self.tokenizer.characters.pad_id
+        self.audio_config = kwargs["audio_config"]
+
         if self.compute_f0:
             self.f0_dataset = ForwardTTSE2eF0Dataset(
                 audio_config=self.audio_config,
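Below is a minimal usage sketch of the refactored `F0Dataset`, illustrating that pitch extraction and caching are now driven by a plain audio config (`sample_rate`, `hop_length`, `pitch_fmax`) instead of an `AudioProcessor` instance. The `BaseAudioConfig` values, sample dict, and cache paths are illustrative assumptions and not part of this patch; only the constructor arguments and the `_compute_and_save_pitch` signature come from the diff above.

```python
# Sketch only: paths, config values, and the sample dict are placeholders.
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.datasets.dataset import F0Dataset

# Fields referenced by the patched pitch code: sample_rate, hop_length, pitch_fmax.
audio_config = BaseAudioConfig(sample_rate=22050, hop_length=256, pitch_fmax=640.0)

samples = [
    {"audio_file": "/data/wavs/sample_0001.wav", "text": "hello world", "speaker_name": "spk0"},
]

# Pitch contours are computed via numpy_transforms.compute_f0 and cached as .npy files.
f0_dataset = F0Dataset(
    samples=samples,
    audio_config=audio_config,  # previously an AudioProcessor instance ("ap")
    cache_path="/tmp/f0_cache",
    precompute_num_workers=0,
)

# One-off computation through the refactored static helper:
pitch = F0Dataset._compute_and_save_pitch(
    audio_config=audio_config,
    wav_file="/data/wavs/sample_0001.wav",
    pitch_file="/tmp/f0_cache/sample_0001_pitch.npy",
)
```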