Update TTSDataset

This commit is contained in:
Eren Gölge 2022-02-20 11:55:40 +01:00
parent 750903d2ba
commit ff23dce081
1 changed file with 13 additions and 82 deletions

View File

@ -37,10 +37,10 @@ def noise_augment_audio(wav):
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step: int,
compute_linear_spec: bool,
ap: AudioProcessor,
samples: List[Dict],
outputs_per_step: int = 1,
compute_linear_spec: bool = False,
ap: AudioProcessor = None,
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
f0_cache_path: str = None,
@ -118,7 +118,6 @@ class TTSDataset(Dataset):
self.batch_group_size = batch_group_size
self._samples = samples
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
@ -153,6 +152,15 @@ class TTSDataset(Dataset):
if self.verbose:
self.print_logs()
@property
def lengths(self):
    """Approximate audio length (in samples) of every dataset item.

    The length is estimated from the wav file size on disk, assuming
    16-bit audio (2 bytes per sample), so no audio decoding is needed.
    """
    # bytes / 16 * 8 == bytes / 2, i.e. one 16-bit sample per two bytes
    return [
        os.path.getsize(wav_file) / 16 * 8
        for (_, wav_file, *_) in map(_parse_sample, self.samples)
    ]
@property
def samples(self):
    """Read-only accessor for the dataset's underlying sample list."""
    current_samples = self._samples
    return current_samples
@ -763,80 +771,3 @@ class F0Dataset:
print(f"{indent}| > Number of instances : {len(self.samples)}")
# if __name__ == "__main__":
# from torch.utils.data import DataLoader
# from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
# from TTS.tts.datasets import load_tts_samples
# from TTS.tts.utils.text.characters import IPAPhonemes
# from TTS.tts.utils.text.phonemizers import ESpeak
# dataset_config = BaseDatasetConfig(
# name="ljspeech",
# meta_file_train="metadata.csv",
# path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1",
# )
# train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# samples = train_samples + eval_samples
# phonemizer = ESpeak(language="en-us")
# tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer)
# # ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests")
# # ph_dataset.precompute(num_workers=4)
# # dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn)
# # for batch in dataloader:
# # print(batch)
# # break
# audio_config = BaseAudioConfig(
# sample_rate=22050,
# win_length=1024,
# hop_length=256,
# num_mels=80,
# preemphasis=0.0,
# ref_level_db=20,
# log_func="np.log",
# do_trim_silence=True,
# trim_db=45,
# mel_fmin=0,
# mel_fmax=8000,
# spec_gain=1.0,
# signal_norm=False,
# do_amp_to_db_linear=False,
# )
# ap = AudioProcessor.init_from_config(audio_config)
# # f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4)
# # dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn)
# # for batch in dataloader:
# # print(batch)
# # breakpoint()
# # break
# dataset = TTSDataset(
# outputs_per_step=1,
# compute_linear_spec=False,
# samples=samples,
# ap=ap,
# return_wav=False,
# batch_group_size=0,
# min_seq_len=0,
# max_seq_len=500,
# use_noise_augment=False,
# verbose=True,
# speaker_id_mapping=None,
# d_vector_mapping=None,
# compute_f0=True,
# f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests",
# tokenizer=tokenizer,
# phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests",
# precompute_num_workers=4,
# )
# dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
# for batch in dataloader:
# print(batch)
# break