mirror of https://github.com/coqui-ai/TTS.git
Refactor TTSDataset to use numpy transforms
This commit is contained in:
parent 3824838e5d
commit 5cd7fa6228
@@ -9,7 +9,8 @@ import tqdm
 from torch.utils.data import Dataset

 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec

 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
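For orientation: the commit swaps methods on a stateful AudioProcessor for the free functions imported above. A minimal before/after sketch, assuming the functions take their inputs positionally exactly as the hunks below call them ("sample.wav" is a placeholder path, not a file from the repo):

    from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec

    wav = load_wav("sample.wav")               # was: ap.load_wav("sample.wav")
    mel = wav_to_mel(wav).astype("float32")    # was: ap.melspectrogram(wav)
    spec = wav_to_spec(wav).astype("float32")  # was: ap.spectrogram(wav)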
@@ -37,9 +38,9 @@ def noise_augment_audio(wav):
 class TTSDataset(Dataset):
     def __init__(
         self,
+        audio_config: "Coqpit" = None,
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
-        ap: AudioProcessor = None,
         samples: List[Dict] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
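At call sites, the dataset would now receive the audio configuration instead of a prebuilt AudioProcessor. A hypothetical construction; config, train_samples, and tokenizer are placeholder names, not taken from the commit:

    dataset = TTSDataset(
        audio_config=config.audio,  # replaces ap=AudioProcessor(**config.audio)
        samples=train_samples,
        tokenizer=tokenizer,
        compute_f0=True,
    )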
@@ -64,12 +65,12 @@ class TTSDataset(Dataset):
         If you need something different, you can subclass and override.

         Args:
+            audio_config (Coqpit): Audio configuration.
+
             outputs_per_step (int): Number of time frames predicted per step.

             compute_linear_spec (bool): compute linear spectrogram if True.

-            ap (TTS.tts.utils.AudioProcessor): Audio processor object.
-
             samples (list): List of dataset samples.

             tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
@@ -115,6 +116,7 @@ class TTSDataset(Dataset):
             verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()
+        self.audio_config = audio_config
         self.batch_group_size = batch_group_size
         self._samples = samples
         self.outputs_per_step = outputs_per_step
@@ -126,7 +128,6 @@ class TTSDataset(Dataset):
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len
         self.max_text_len = max_text_len
-        self.ap = ap
         self.phoneme_cache_path = phoneme_cache_path
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
@@ -146,7 +147,7 @@ class TTSDataset(Dataset):

         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
             )

         if self.verbose:
@@ -188,7 +189,7 @@ class TTSDataset(Dataset):
         print(f"{indent}| > Number of instances : {len(self.samples)}")

     def load_wav(self, filename):
-        waveform = self.ap.load_wav(filename)
+        waveform = load_wav(filename)
         assert waveform.size > 0
         return waveform

@@ -408,7 +409,7 @@ class TTSDataset(Dataset):
         else:
             speaker_ids = None
         # compute features
-        mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
+        mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]]

         mel_lengths = [m.shape[1] for m in mel]

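Each entry of mel is one spectrogram laid out as (n_mels, frames), which is why the per-utterance length is read from shape[1]. A toy check under that layout assumption:

    import numpy as np

    mel = np.zeros((80, 123), dtype=np.float32)  # 80 mel bands, 123 frames
    mel_lengths = [m.shape[1] for m in [mel]]
    assert mel_lengths == [123]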
@@ -455,7 +456,7 @@ class TTSDataset(Dataset):
         # compute linear spectrogram
         linear = None
         if self.compute_linear_spec:
-            linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
+            linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]]
             linear = prepare_tensor(linear, self.outputs_per_step)
             linear = linear.transpose(0, 2, 1)
             assert mel.shape[1] == linear.shape[1]
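After prepare_tensor pads the list into one batch array, the transpose reorders it so the assert compares mel and linear along the shared time axis. Assuming the batch is laid out as (batch, bins, frames) before the call, the reorder in isolation:

    import numpy as np

    linear = np.zeros((2, 513, 40), dtype=np.float32)  # (batch, bins, frames)
    linear = linear.transpose(0, 2, 1)                 # -> (batch, frames, bins)
    assert linear.shape == (2, 40, 513)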
@@ -465,13 +466,13 @@ class TTSDataset(Dataset):
         wav_padded = None
         if self.return_wav:
             wav_lengths = [w.shape[0] for w in batch["wav"]]
-            max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
+            max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length
             wav_lengths = torch.LongTensor(wav_lengths)
             wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
             for i, w in enumerate(batch["wav"]):
                 mel_length = mel_lengths_adjusted[i]
-                w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
-                w = w[: mel_length * self.ap.hop_length]
+                w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge")
+                w = w[: mel_length * self.audio_config.hop_length]
                 wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
             wav_padded.transpose_(1, 2)

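The loop aligns each waveform to its adjusted mel length: edge-pad by one extra step's worth of samples, then trim to exactly mel_length * hop_length. The same arithmetic in isolation, with illustrative values:

    import numpy as np

    hop_length, outputs_per_step = 256, 1  # illustrative values
    mel_length = 20                        # adjusted frame count for one clip
    w = np.zeros(5000, dtype=np.float32)   # stand-in waveform

    w = np.pad(w, (0, hop_length * outputs_per_step), mode="edge")
    w = w[: mel_length * hop_length]
    assert w.shape[0] == mel_length * hop_length  # 5120 samples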
@@ -654,7 +655,7 @@ class F0Dataset:
         normalize_f0=True,
     ):
         self.samples = samples
-        self.ap = ap
+        self.audio_config = ap
         self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
@@ -750,7 +751,7 @@ class F0Dataset:
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
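This method is a compute-or-load cache: the pitch contour is extracted once, saved as a .npy file, and reloaded on later epochs. The pattern in isolation, with a stub standing in for the real extractor (_compute_and_save_pitch):

    import os
    import numpy as np

    def compute_pitch_stub(wav_file: str) -> np.ndarray:
        # Stand-in for the real F0 extractor.
        return np.zeros(100, dtype=np.float32)

    def cached_pitch(wav_file: str, pitch_file: str) -> np.ndarray:
        # pitch_file is expected to end in ".npy" so save and load agree.
        if not os.path.exists(pitch_file):
            pitch = compute_pitch_stub(wav_file)
            np.save(pitch_file, pitch)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)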