Refactor TTSDataset to use numpy transforms

This commit is contained in:
Eren Gölge 2022-04-19 09:23:18 +02:00
parent 3824838e5d
commit 5cd7fa6228
1 changed files with 15 additions and 14 deletions

View File

@ -9,7 +9,8 @@ import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@ -37,9 +38,9 @@ def noise_augment_audio(wav):
class TTSDataset(Dataset):
def __init__(
self,
audio_config: "Coqpit" = None,
outputs_per_step: int = 1,
compute_linear_spec: bool = False,
ap: AudioProcessor = None,
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
@ -64,12 +65,12 @@ class TTSDataset(Dataset):
If you need something different, you can subclass and override.
Args:
audio_config (Coqpit): Audio configuration.
outputs_per_step (int): Number of time frames predicted per step.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
samples (list): List of dataset samples.
tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
@ -115,6 +116,7 @@ class TTSDataset(Dataset):
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
self.audio_config = audio_config
self.batch_group_size = batch_group_size
self._samples = samples
self.outputs_per_step = outputs_per_step
@ -126,7 +128,6 @@ class TTSDataset(Dataset):
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
self.max_text_len = max_text_len
self.ap = ap
self.phoneme_cache_path = phoneme_cache_path
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
@ -146,7 +147,7 @@ class TTSDataset(Dataset):
if compute_f0:
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
@ -188,7 +189,7 @@ class TTSDataset(Dataset):
print(f"{indent}| > Number of instances : {len(self.samples)}")
def load_wav(self, filename):
waveform = self.ap.load_wav(filename)
waveform = load_wav(filename)
assert waveform.size > 0
return waveform
@ -408,7 +409,7 @@ class TTSDataset(Dataset):
else:
speaker_ids = None
# compute features
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]]
mel_lengths = [m.shape[1] for m in mel]
@ -455,7 +456,7 @@ class TTSDataset(Dataset):
# compute linear spectrogram
linear = None
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
@ -465,13 +466,13 @@ class TTSDataset(Dataset):
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in batch["wav"]]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
for i, w in enumerate(batch["wav"]):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.audio_config.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
@ -654,7 +655,7 @@ class F0Dataset:
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.audio_config = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
@ -750,7 +751,7 @@ class F0Dataset:
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)