diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index d8f16e4e..53df94a3 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -9,7 +9,8 @@
 import tqdm
 from torch.utils.data import Dataset

 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
+
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -37,9 +38,9 @@
 class TTSDataset(Dataset):
     def __init__(
         self,
+        audio_config: "Coqpit" = None,
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
-        ap: AudioProcessor = None,
         samples: List[Dict] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
@@ -64,12 +65,12 @@ class TTSDataset(Dataset):
         If you need something different, you can subclass and override.

         Args:
+            audio_config (Coqpit): Audio configuration.
+
             outputs_per_step (int): Number of time frames predicted per step.

             compute_linear_spec (bool): compute linear spectrogram if True.

-            ap (TTS.tts.utils.AudioProcessor): Audio processor object.
-
             samples (list): List of dataset samples.

             tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
@@ -115,6 +116,7 @@ class TTSDataset(Dataset):
             verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()
+        self.audio_config = audio_config
         self.batch_group_size = batch_group_size
         self._samples = samples
         self.outputs_per_step = outputs_per_step
@@ -126,7 +128,6 @@ class TTSDataset(Dataset):
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len
         self.max_text_len = max_text_len
-        self.ap = ap
         self.phoneme_cache_path = phoneme_cache_path
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
@@ -146,7 +147,7 @@ class TTSDataset(Dataset):

         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
             )

         if self.verbose:
@@ -188,7 +189,7 @@ class TTSDataset(Dataset):
         print(f"{indent}| > Number of instances : {len(self.samples)}")

     def load_wav(self, filename):
-        waveform = self.ap.load_wav(filename)
+        waveform = load_wav(filename)
         assert waveform.size > 0
         return waveform

@@ -408,7 +409,7 @@ class TTSDataset(Dataset):
             else:
                 speaker_ids = None
             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
+            mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]]

             mel_lengths = [m.shape[1] for m in mel]

@@ -455,7 +456,7 @@ class TTSDataset(Dataset):
             # compute linear spectrogram
             linear = None
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
+                linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
@@ -465,13 +466,13 @@ class TTSDataset(Dataset):
             wav_padded = None
             if self.return_wav:
                 wav_lengths = [w.shape[0] for w in batch["wav"]]
-                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
+                max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length
                 wav_lengths = torch.LongTensor(wav_lengths)
                 wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                 for i, w in enumerate(batch["wav"]):
                     mel_length = mel_lengths_adjusted[i]
-                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
-                    w = w[: mel_length * self.ap.hop_length]
+                    w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge")
+                    w = w[: mel_length * self.audio_config.hop_length]
                     wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                 wav_padded.transpose_(1, 2)

@@ -654,7 +655,7 @@ class F0Dataset:
         normalize_f0=True,
     ):
         self.samples = samples
-        self.ap = ap
+        self.audio_config = ap
         self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
@@ -750,7 +751,7 @@ class F0Dataset:
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
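
For reviewers, a minimal usage sketch of the constructor signature after this patch (not part of the diff). It assumes `BaseAudioConfig` from `TTS.config.shared_configs` as the `Coqpit` audio config, and the sample-dict keys and file paths shown are hypothetical placeholders:

```python
# Minimal sketch, assuming BaseAudioConfig carries the hop_length field that
# collate_fn reads via self.audio_config. Not part of the patch itself.
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.datasets.dataset import TTSDataset

audio_config = BaseAudioConfig(
    sample_rate=22050,
    hop_length=256,  # used to size and pad the batched waveform tensor
)

# Hypothetical sample list for illustration; in practice samples are
# produced by the dataset loading utilities.
samples = [
    {"text": "Hello world.", "audio_file": "wavs/0001.wav", "speaker_name": "spk0"},
]

dataset = TTSDataset(
    audio_config=audio_config,  # replaces the removed ap=AudioProcessor(...) argument
    outputs_per_step=1,
    compute_linear_spec=True,  # batches additionally carry wav_to_spec() features
    samples=samples,
    return_wav=True,  # exercises the audio_config.hop_length padding path
)
```

Note that `F0Dataset` keeps its positional `ap` parameter name while binding it to `self.audio_config`, so the positional call in `TTSDataset.__init__` (`F0Dataset(self.samples, self.audio_config, ...)`) still lines up with the old call sites.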