mirror of https://github.com/coqui-ai/TTS.git
Refactor TTSDataset to use numpy transforms
This commit is contained in:
parent 3824838e5d
commit 5cd7fa6228
@@ -9,7 +9,8 @@ import tqdm
 from torch.utils.data import Dataset

 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec


 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
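
The import change above is the heart of the refactor: feature extraction moves from methods on an AudioProcessor instance to the stateless functions in TTS.utils.audio.numpy_transforms. A minimal before/after sketch of the call style, using only the call forms visible in this diff (the AudioProcessor setup line is illustrative, and any keyword arguments the numpy transforms may expect are not shown here):

    # Before: audio features are produced by methods on a shared AudioProcessor.
    ap = AudioProcessor(**config.audio)   # illustrative setup, not part of this diff
    wav = ap.load_wav(filename)
    mel = ap.melspectrogram(wav)

    # After: module-level numpy transforms are called directly; audio settings
    # travel with the dataset as a plain Coqpit config instead of an object.
    wav = load_wav(filename)
    mel = wav_to_mel(wav)
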
@@ -37,9 +38,9 @@ def noise_augment_audio(wav):
 class TTSDataset(Dataset):
     def __init__(
         self,
+        audio_config: "Coqpit" = None,
         outputs_per_step: int = 1,
         compute_linear_spec: bool = False,
-        ap: AudioProcessor = None,
         samples: List[Dict] = None,
         tokenizer: "TTSTokenizer" = None,
         compute_f0: bool = False,
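
For callers, the signature change means the dataset is now constructed from the audio configuration rather than a pre-built AudioProcessor. A hedged usage sketch; config, train_samples, and tokenizer are placeholder names and not part of this diff:

    dataset = TTSDataset(
        audio_config=config.audio,     # was: ap=AudioProcessor(**config.audio)
        outputs_per_step=1,
        compute_linear_spec=False,
        samples=train_samples,         # list of dicts, unchanged by this commit
        tokenizer=tokenizer,
        compute_f0=False,
    )
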
@@ -64,12 +65,12 @@ class TTSDataset(Dataset):
         If you need something different, you can subclass and override.

         Args:
+            audio_config (Coqpit): Audio configuration.
+
             outputs_per_step (int): Number of time frames predicted per step.

             compute_linear_spec (bool): compute linear spectrogram if True.

-            ap (TTS.tts.utils.AudioProcessor): Audio processor object.
-
             samples (list): List of dataset samples.

             tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
@@ -115,6 +116,7 @@ class TTSDataset(Dataset):
             verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()
+        self.audio_config = audio_config
         self.batch_group_size = batch_group_size
         self._samples = samples
         self.outputs_per_step = outputs_per_step
@@ -126,7 +128,6 @@ class TTSDataset(Dataset):
         self.max_audio_len = max_audio_len
         self.min_text_len = min_text_len
         self.max_text_len = max_text_len
-        self.ap = ap
         self.phoneme_cache_path = phoneme_cache_path
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
@@ -146,7 +147,7 @@ class TTSDataset(Dataset):

         if compute_f0:
             self.f0_dataset = F0Dataset(
-                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
+                self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
             )

         if self.verbose:
@@ -188,7 +189,7 @@ class TTSDataset(Dataset):
         print(f"{indent}| > Number of instances : {len(self.samples)}")

     def load_wav(self, filename):
-        waveform = self.ap.load_wav(filename)
+        waveform = load_wav(filename)
         assert waveform.size > 0
         return waveform

@@ -408,7 +409,7 @@ class TTSDataset(Dataset):
             else:
                 speaker_ids = None
             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
+            mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]]

             mel_lengths = [m.shape[1] for m in mel]

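
Around the changed line, the collate step turns a list of per-utterance spectrograms into one padded batch. A rough, self-contained numpy illustration of that shape handling (it approximates what prepare_tensor does, omits the rounding to a multiple of outputs_per_step, and uses made-up sizes; it is not the library implementation):

    import numpy as np

    # Each utterance yields a [n_mels, T_i] array; pad to the longest T and stack,
    # then transpose to [batch, T_max, n_mels], matching the transpose(0, 2, 1) below.
    mels = [np.random.rand(80, t).astype("float32") for t in (120, 95, 130)]
    t_max = max(m.shape[1] for m in mels)
    batch = np.stack([np.pad(m, ((0, 0), (0, t_max - m.shape[1]))) for m in mels])
    batch = batch.transpose(0, 2, 1)   # -> (3, 130, 80)
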
@@ -455,7 +456,7 @@ class TTSDataset(Dataset):
             # compute linear spectrogram
             linear = None
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
+                linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
@@ -465,13 +466,13 @@ class TTSDataset(Dataset):
             wav_padded = None
             if self.return_wav:
                 wav_lengths = [w.shape[0] for w in batch["wav"]]
-                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
+                max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length
                 wav_lengths = torch.LongTensor(wav_lengths)
                 wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                 for i, w in enumerate(batch["wav"]):
                     mel_length = mel_lengths_adjusted[i]
-                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
-                    w = w[: mel_length * self.ap.hop_length]
+                    w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge")
+                    w = w[: mel_length * self.audio_config.hop_length]
                     wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                 wav_padded.transpose_(1, 2)

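
The loop above keeps each returned waveform exactly aligned with its mel frames: pad a little at the end, then trim to mel_length frames' worth of samples. A minimal standalone sketch of that per-sample logic, with hop_length, outputs_per_step, and the array sizes chosen only for illustration:

    import numpy as np
    import torch

    hop_length, outputs_per_step = 256, 1     # illustrative values, not from this diff
    mel_length = 120                          # adjusted mel frame count for one sample
    w = np.random.randn(30500).astype("float32")

    w = np.pad(w, (0, hop_length * outputs_per_step), mode="edge")  # guard padding
    w = w[: mel_length * hop_length]          # trim so samples == frames * hop_length
    assert w.shape[0] == mel_length * hop_length
    wav_tensor = torch.from_numpy(w)          # ready to copy into wav_padded
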
@@ -654,7 +655,7 @@ class F0Dataset:
         normalize_f0=True,
     ):
         self.samples = samples
-        self.ap = ap
+        self.audio_config = ap
         self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
@@ -750,7 +751,7 @@ class F0Dataset:
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)

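
The pitch cache pattern itself is untouched; only the object handed to _compute_and_save_pitch changes from self.ap to self.audio_config. A generic sketch of that compute-or-load pattern, with compute_fn standing in for the real pitch extractor (a placeholder, not the library API):

    import os
    import numpy as np

    def cached_pitch(wav_file, pitch_file, compute_fn):
        """Compute F0 once and reuse the cached array on later epochs/workers."""
        if not os.path.exists(pitch_file):
            pitch = compute_fn(wav_file)      # stands in for _compute_and_save_pitch
            np.save(pitch_file, pitch)        # pitch_file is expected to end in .npy
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)
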