mirror of https://github.com/coqui-ai/TTS.git
Refactor TTSDataset ⚡️
This commit is contained in:
parent
4597d4e5b6
commit
176b712c1a
|
@ -2,7 +2,7 @@ import collections
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -14,6 +14,24 @@ from TTS.tts.utils.text import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_sample(item):
|
||||||
|
language_name = None
|
||||||
|
attn_file = None
|
||||||
|
if len(item) == 5:
|
||||||
|
text, wav_file, speaker_name, language_name, attn_file = item
|
||||||
|
elif len(item) == 4:
|
||||||
|
text, wav_file, speaker_name, language_name = item
|
||||||
|
elif len(item) == 3:
|
||||||
|
text, wav_file, speaker_name = item
|
||||||
|
else:
|
||||||
|
raise ValueError(" [!] Dataset cannot parse the sample.")
|
||||||
|
return text, wav_file, speaker_name, language_name, attn_file
|
||||||
|
|
||||||
|
|
||||||
|
def noise_augment_audio(wav):
|
||||||
|
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
|
||||||
|
|
||||||
|
|
||||||
class TTSDataset(Dataset):
|
class TTSDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -26,9 +44,12 @@ class TTSDataset(Dataset):
|
||||||
f0_cache_path: str = None,
|
f0_cache_path: str = None,
|
||||||
return_wav: bool = False,
|
return_wav: bool = False,
|
||||||
batch_group_size: int = 0,
|
batch_group_size: int = 0,
|
||||||
min_seq_len: int = 0,
|
min_text_len: int = 0,
|
||||||
max_seq_len: int = float("inf"),
|
max_text_len: int = float("inf"),
|
||||||
|
min_audio_len: int = 0,
|
||||||
|
max_audio_len: int = float("inf"),
|
||||||
phoneme_cache_path: str = None,
|
phoneme_cache_path: str = None,
|
||||||
|
precompute_num_workers: int = 0,
|
||||||
speaker_id_mapping: Dict = None,
|
speaker_id_mapping: Dict = None,
|
||||||
d_vector_mapping: Dict = None,
|
d_vector_mapping: Dict = None,
|
||||||
language_id_mapping: Dict = None,
|
language_id_mapping: Dict = None,
|
||||||
|
@ -37,7 +58,7 @@ class TTSDataset(Dataset):
|
||||||
):
|
):
|
||||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||||
|
|
||||||
If you need something different, you can inherit and override.
|
If you need something different, you can subclass and override.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs_per_step (int): Number of time frames predicted per step.
|
outputs_per_step (int): Number of time frames predicted per step.
|
||||||
|
@ -61,17 +82,24 @@ class TTSDataset(Dataset):
|
||||||
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
||||||
batch. Set 0 to disable. Defaults to 0.
|
batch. Set 0 to disable. Defaults to 0.
|
||||||
|
|
||||||
min_seq_len (int): Minimum input sequence length to be processed
|
min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
|
||||||
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
|
Defaults to 0.
|
||||||
minimum input length due to its architecture. Defaults to 0.
|
|
||||||
|
|
||||||
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
|
||||||
It helps for controlling the VRAM usage against long input sequences. Especially models with
|
Defaults to float("inf").
|
||||||
RNN layers are sensitive to input length. Defaults to `Inf`.
|
|
||||||
|
min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
|
||||||
|
Defaults to 0.
|
||||||
|
|
||||||
|
max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
|
||||||
|
The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
|
||||||
|
this value if you encounter an OOM error in training. Defaults to float("inf").
|
||||||
|
|
||||||
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
|
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
|
||||||
separate file. Defaults to None.
|
separate file. Defaults to None.
|
||||||
|
|
||||||
|
precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.
|
||||||
|
|
||||||
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
||||||
embedding layer. Defaults to None.
|
embedding layer. Defaults to None.
|
||||||
|
|
||||||
|
@ -83,15 +111,17 @@ class TTSDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_group_size = batch_group_size
|
self.batch_group_size = batch_group_size
|
||||||
self.items = meta_data
|
self._samples = meta_data
|
||||||
self.outputs_per_step = outputs_per_step
|
self.outputs_per_step = outputs_per_step
|
||||||
self.sample_rate = ap.sample_rate
|
self.sample_rate = ap.sample_rate
|
||||||
self.compute_linear_spec = compute_linear_spec
|
self.compute_linear_spec = compute_linear_spec
|
||||||
self.return_wav = return_wav
|
self.return_wav = return_wav
|
||||||
self.compute_f0 = compute_f0
|
self.compute_f0 = compute_f0
|
||||||
self.f0_cache_path = f0_cache_path
|
self.f0_cache_path = f0_cache_path
|
||||||
self.min_seq_len = min_seq_len
|
self.min_audio_len = min_audio_len
|
||||||
self.max_seq_len = max_seq_len
|
self.max_audio_len = max_audio_len
|
||||||
|
self.min_text_len = min_text_len
|
||||||
|
self.max_text_len = max_text_len
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.phoneme_cache_path = phoneme_cache_path
|
self.phoneme_cache_path = phoneme_cache_path
|
||||||
self.speaker_id_mapping = speaker_id_mapping
|
self.speaker_id_mapping = speaker_id_mapping
|
||||||
|
@ -100,112 +130,113 @@ class TTSDataset(Dataset):
|
||||||
self.use_noise_augment = use_noise_augment
|
self.use_noise_augment = use_noise_augment
|
||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.input_seq_computed = False
|
|
||||||
self.rescue_item_idx = 1
|
self.rescue_item_idx = 1
|
||||||
self.pitch_computed = False
|
self.pitch_computed = False
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
if self.tokenizer.use_phonemes and not os.path.isdir(phoneme_cache_path):
|
self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
|
||||||
|
if self.tokenizer.use_phonemes:
|
||||||
|
self.phoneme_dataset = PhonemeDataset(
|
||||||
|
self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
|
||||||
|
)
|
||||||
|
|
||||||
if compute_f0:
|
if compute_f0:
|
||||||
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
|
self.f0_dataset = F0Dataset(
|
||||||
|
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
|
||||||
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.print_logs()
|
self.print_logs()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def samples(self):
|
||||||
|
return self._samples
|
||||||
|
|
||||||
|
@samples.setter
|
||||||
|
def samples(self, new_samples):
|
||||||
|
self._samples = new_samples
|
||||||
|
if hasattr(self, "f0_dataset"):
|
||||||
|
self.f0_dataset.samples = new_samples
|
||||||
|
if hasattr(self, "phoneme_dataset"):
|
||||||
|
self.phoneme_dataset.samples = new_samples
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.load_data(idx)
|
||||||
|
|
||||||
def print_logs(self, level: int = 0) -> None:
|
def print_logs(self, level: int = 0) -> None:
|
||||||
indent = "\t" * level
|
indent = "\t" * level
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"{indent}> DataLoader initialization")
|
print(f"{indent}> DataLoader initialization")
|
||||||
print(f"{indent}| > Tokenizer:")
|
print(f"{indent}| > Tokenizer:")
|
||||||
self.tokenizer.print_logs(level + 1)
|
self.tokenizer.print_logs(level + 1)
|
||||||
print(f"{indent}| > Number of instances : {len(self.items)}")
|
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||||
|
|
||||||
def load_wav(self, filename):
|
def load_wav(self, filename):
|
||||||
audio = self.ap.load_wav(filename)
|
waveform = self.ap.load_wav(filename)
|
||||||
return audio
|
assert waveform.size > 0
|
||||||
|
return waveform
|
||||||
|
|
||||||
@staticmethod
|
def get_phonemes(self, idx, text):
|
||||||
def load_np(filename):
|
out_dict = self.phoneme_dataset[idx]
|
||||||
data = np.load(filename).astype("float32")
|
assert text == out_dict["text"], f"{text} != {out_dict['text']}"
|
||||||
return data
|
assert out_dict["token_ids"].size > 0
|
||||||
|
return out_dict
|
||||||
|
|
||||||
@staticmethod
|
def get_f0(self, idx):
|
||||||
def _generate_and_cache_phoneme_sequence(text, tokenizer, cache_path):
|
out_dict = self.f0_dataset[idx]
|
||||||
"""generate a phoneme sequence from text.
|
_, wav_file, *_ = _parse_sample(self.samples[idx])
|
||||||
since the usage is for subsequent caching, we never add bos and
|
assert wav_file == out_dict["audio_file"]
|
||||||
eos chars here. Instead we add those dynamically later; based on the
|
return out_dict
|
||||||
config option."""
|
|
||||||
phonemes = tokenizer.text_to_ids(text)
|
|
||||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
|
||||||
np.save(cache_path, phonemes)
|
|
||||||
return phonemes
|
|
||||||
|
|
||||||
@staticmethod
|
def get_attn_maks(self, attn_file):
|
||||||
def _load_or_generate_phoneme_sequence(wav_file, text, language, tokenizer, phoneme_cache_path):
|
return np.load(attn_file)
|
||||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
|
||||||
|
|
||||||
# different names for normal phonemes and with blank chars.
|
def get_token_ids(self, idx, text):
|
||||||
file_name_ext = "_phoneme.npy"
|
if self.tokenizer.use_phonemes:
|
||||||
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
|
token_ids = self.get_phonemes(idx, text)["token_ids"]
|
||||||
try:
|
else:
|
||||||
phonemes = np.load(cache_path)
|
token_ids = self.tokenizer.text_to_ids(text)
|
||||||
except FileNotFoundError:
|
return token_ids
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
|
|
||||||
except (ValueError, IOError):
|
|
||||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(text, tokenizer, cache_path)
|
|
||||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
|
||||||
return phonemes
|
|
||||||
|
|
||||||
def load_data(self, idx):
|
def load_data(self, idx):
|
||||||
item = self.items[idx]
|
item = self.samples[idx]
|
||||||
|
|
||||||
raw_text = item["text"]
|
raw_text = item["text"]
|
||||||
|
|
||||||
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
|
wav = np.asarray(self.load_wav(item[]), dtype=np.float32)
|
||||||
|
|
||||||
# apply noise for augmentation
|
# apply noise for augmentation
|
||||||
if self.use_noise_augment:
|
if self.use_noise_augment:
|
||||||
wav = wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
|
wav = noise_augment_audio(wav)
|
||||||
|
|
||||||
if not self.input_seq_computed:
|
# get token ids
|
||||||
if self.tokenizer.use_phonemes:
|
token_ids = self.get_token_ids(idx, item["text"])
|
||||||
text = self._load_or_generate_phoneme_sequence(
|
|
||||||
item["audio_file"],
|
|
||||||
item["text"],
|
|
||||||
item["language"] if item["language"] else self.phoneme_language,
|
|
||||||
self.tokenizer,
|
|
||||||
self.phoneme_cache_path,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text = np.asarray(
|
|
||||||
self.tokenizer.text_to_ids(item["text"], item["language"]),
|
|
||||||
dtype=np.int32,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert text.size > 0, self.items[idx]["audio_file"]
|
|
||||||
assert wav.size > 0, self.items[idx]["audio_file"]
|
|
||||||
|
|
||||||
|
# get pre-computed attention maps
|
||||||
attn = None
|
attn = None
|
||||||
if "alignment_file" in item:
|
if "alignment_file" in item:
|
||||||
attn = np.load(item["alignment_file"])
|
attn = self.get_attn_mask(item["alignment_file"])
|
||||||
|
|
||||||
if len(text) > self.max_seq_len:
|
# after phonemization the text length may change
|
||||||
# return a different sample if the phonemized
|
# this is a shareful 🤭 hack to prevent longer phonemes
|
||||||
# text is longer than the threshold
|
# TODO: find a better fix
|
||||||
# TODO: find a better fix
|
if len(token_ids) > self.max_text_len:
|
||||||
return self.load_data(self.rescue_item_idx)
|
return self.load_data(self.rescue_item_idx)
|
||||||
|
|
||||||
pitch = None
|
# get f0 values
|
||||||
|
f0 = None
|
||||||
if self.compute_f0:
|
if self.compute_f0:
|
||||||
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
|
f0 = self.get_f0(idx)["f0"]
|
||||||
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
|
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
"raw_text": raw_text,
|
"raw_text": raw_text,
|
||||||
"text": text,
|
"token_ids": token_ids,
|
||||||
"wav": wav,
|
"wav": wav,
|
||||||
"pitch": pitch,
|
"pitch": f0,
|
||||||
"attn": attn,
|
"attn": attn,
|
||||||
"item_idx": item["audio_file"],
|
"item_idx": item["audio_file"],
|
||||||
"speaker_name": item["speaker_name"],
|
"speaker_name": item["speaker_name"],
|
||||||
|
@ -215,105 +246,78 @@ class TTSDataset(Dataset):
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _phoneme_worker(args):
|
def compute_lengths(samples):
|
||||||
item = args[0]
|
audio_lengths = []
|
||||||
func_args = args[1]
|
text_lengths = []
|
||||||
func_args[3] = (
|
for item in samples:
|
||||||
item["language"] if "language" in item and item["language"] else func_args[3]
|
text, wav_file, *_ = _parse_sample(item)
|
||||||
) # override phoneme language if specified by the dataset formatter
|
audio_lengths.append(os.path.getsize(wav_file) / 16 * 8) # assuming 16bit audio
|
||||||
phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
|
text_lengths.append(len(text))
|
||||||
return phonemes
|
audio_lengths = np.array(audio_lengths)
|
||||||
|
text_lengths = np.array(text_lengths)
|
||||||
|
return audio_lengths, text_lengths
|
||||||
|
|
||||||
def compute_input_seq(self, num_workers=0):
|
@staticmethod
|
||||||
"""Compute the input sequences with multi-processing.
|
def sort_and_filter_by_length(lengths:List[int], min_len:int, max_len:int):
|
||||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
idxs = np.argsort(lengths) # ascending order
|
||||||
if not self.use_phonemes:
|
ignore_idx = []
|
||||||
if self.verbose:
|
keep_idx = []
|
||||||
print(" | > Computing input sequences ...")
|
|
||||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
|
||||||
sequence = np.asarray(
|
|
||||||
self.tokenizer.text_to_ids(item["text"]),
|
|
||||||
dtype=np.int32,
|
|
||||||
)
|
|
||||||
self.items[idx][0] = sequence
|
|
||||||
else:
|
|
||||||
func_args = [
|
|
||||||
self.phoneme_cache_path,
|
|
||||||
self.enable_eos_bos,
|
|
||||||
self.cleaners,
|
|
||||||
self.phoneme_language,
|
|
||||||
self.characters,
|
|
||||||
self.add_blank,
|
|
||||||
]
|
|
||||||
if self.verbose:
|
|
||||||
print(" | > Computing phonemes ...")
|
|
||||||
if num_workers == 0:
|
|
||||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
|
||||||
phonemes = self._phoneme_worker([item, func_args])
|
|
||||||
self.items[idx][0] = phonemes
|
|
||||||
else:
|
|
||||||
with Pool(num_workers) as p:
|
|
||||||
phonemes = list(
|
|
||||||
tqdm.tqdm(
|
|
||||||
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
|
|
||||||
total=len(self.items),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for idx, p in enumerate(phonemes):
|
|
||||||
self.items[idx][0] = p
|
|
||||||
|
|
||||||
def sort_and_filter_items(self, by_audio_len=False):
|
|
||||||
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
|
||||||
range.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
by_audio_len (bool): if True, sort by audio length else by text length.
|
|
||||||
"""
|
|
||||||
# compute the target sequence length
|
|
||||||
if by_audio_len:
|
|
||||||
lengths = []
|
|
||||||
for item in self.items:
|
|
||||||
lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio
|
|
||||||
lengths = np.array(lengths)
|
|
||||||
else:
|
|
||||||
lengths = np.array([len(ins["text"]) for ins in self.items])
|
|
||||||
|
|
||||||
idxs = np.argsort(lengths)
|
|
||||||
new_items = []
|
|
||||||
ignored = []
|
|
||||||
for i, idx in enumerate(idxs):
|
for i, idx in enumerate(idxs):
|
||||||
length = lengths[idx]
|
length = lengths[idx]
|
||||||
if length < self.min_seq_len or length > self.max_seq_len:
|
if length < min_len or length > max_len:
|
||||||
ignored.append(idx)
|
ignore_idx.append(idx)
|
||||||
else:
|
else:
|
||||||
new_items.append(self.items[idx])
|
keep_idx.append(idx)
|
||||||
|
return ignore_idx, keep_idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_buckets(samples, batch_group_size:int):
|
||||||
|
for i in range(len(samples) // batch_group_size):
|
||||||
|
offset = i * batch_group_size
|
||||||
|
end_offset = offset + batch_group_size
|
||||||
|
temp_items = samples[offset:end_offset]
|
||||||
|
random.shuffle(temp_items)
|
||||||
|
samples[offset:end_offset] = temp_items
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def preprocess_samples(self):
|
||||||
|
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||||
|
range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# sort items based on the sequence length in ascending order
|
||||||
|
text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
|
||||||
|
audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len)
|
||||||
|
keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
|
||||||
|
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for idx in keep_idx:
|
||||||
|
samples.append(self.samples[idx])
|
||||||
|
|
||||||
|
if len(samples) == 0:
|
||||||
|
raise RuntimeError(" [!] No samples left")
|
||||||
|
|
||||||
# shuffle batch groups
|
# shuffle batch groups
|
||||||
if self.batch_group_size > 0:
|
# create batches with similar length items
|
||||||
for i in range(len(new_items) // self.batch_group_size):
|
# the larger the `batch_group_size`, the higher the length variety in a batch.
|
||||||
offset = i * self.batch_group_size
|
samples = self.create_buckets(samples, self.batch_group_size)
|
||||||
end_offset = offset + self.batch_group_size
|
|
||||||
temp_items = new_items[offset:end_offset]
|
# update items to the new sorted items
|
||||||
random.shuffle(temp_items)
|
self.samples = samples
|
||||||
new_items[offset:end_offset] = temp_items
|
|
||||||
self.items = new_items
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
print(" | > Preprocessing samples")
|
||||||
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
print(" | > Max text length: {}".format(np.max(self.text_lengths)))
|
||||||
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
|
print(" | > Min text length: {}".format(np.min(self.text_lengths)))
|
||||||
print(
|
print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
|
||||||
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
|
print(" | ")
|
||||||
self.max_seq_len, self.min_seq_len, len(ignored)
|
print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
|
||||||
)
|
print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
|
||||||
)
|
print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
|
||||||
|
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
|
||||||
print(" | > Batch group size: {}.".format(self.batch_group_size))
|
print(" | > Batch group size: {}.".format(self.batch_group_size))
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.items)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.load_data(idx)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sort_batch(batch, text_lengths):
|
def _sort_batch(batch, text_lengths):
|
||||||
"""Sort the batch by the input text length for RNN efficiency.
|
"""Sort the batch by the input text length for RNN efficiency.
|
||||||
|
@ -338,10 +342,10 @@ class TTSDataset(Dataset):
|
||||||
# Puts each data field into a tensor with outer dimension batch size
|
# Puts each data field into a tensor with outer dimension batch size
|
||||||
if isinstance(batch[0], collections.abc.Mapping):
|
if isinstance(batch[0], collections.abc.Mapping):
|
||||||
|
|
||||||
text_lengths = np.array([len(d["text"]) for d in batch])
|
token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])
|
||||||
|
|
||||||
# sort items with text input length for RNN efficiency
|
# sort items with text input length for RNN efficiency
|
||||||
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
|
batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths)
|
||||||
|
|
||||||
# convert list of dicts to dict of lists
|
# convert list of dicts to dict of lists
|
||||||
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||||
|
@ -383,7 +387,7 @@ class TTSDataset(Dataset):
|
||||||
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
||||||
|
|
||||||
# PAD sequences with longest instance in the batch
|
# PAD sequences with longest instance in the batch
|
||||||
text = prepare_data(batch["text"]).astype(np.int32)
|
text = prepare_data(batch["token_ids"]).astype(np.int32)
|
||||||
|
|
||||||
# PAD features with longest instance
|
# PAD features with longest instance
|
||||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||||
|
@ -392,12 +396,13 @@ class TTSDataset(Dataset):
|
||||||
mel = mel.transpose(0, 2, 1)
|
mel = mel.transpose(0, 2, 1)
|
||||||
|
|
||||||
# convert things to pytorch
|
# convert things to pytorch
|
||||||
text_lengths = torch.LongTensor(text_lengths)
|
token_ids_lengths = torch.LongTensor(token_ids_lengths)
|
||||||
text = torch.LongTensor(text)
|
text = torch.LongTensor(text)
|
||||||
mel = torch.FloatTensor(mel).contiguous()
|
mel = torch.FloatTensor(mel).contiguous()
|
||||||
mel_lengths = torch.LongTensor(mel_lengths)
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
stop_targets = torch.FloatTensor(stop_targets)
|
stop_targets = torch.FloatTensor(stop_targets)
|
||||||
|
|
||||||
|
# speaker vectors
|
||||||
if d_vectors is not None:
|
if d_vectors is not None:
|
||||||
d_vectors = torch.FloatTensor(d_vectors)
|
d_vectors = torch.FloatTensor(d_vectors)
|
||||||
|
|
||||||
|
@ -408,14 +413,13 @@ class TTSDataset(Dataset):
|
||||||
language_ids = torch.LongTensor(language_ids)
|
language_ids = torch.LongTensor(language_ids)
|
||||||
|
|
||||||
# compute linear spectrogram
|
# compute linear spectrogram
|
||||||
|
linear = None
|
||||||
if self.compute_linear_spec:
|
if self.compute_linear_spec:
|
||||||
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
||||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||||
linear = linear.transpose(0, 2, 1)
|
linear = linear.transpose(0, 2, 1)
|
||||||
assert mel.shape[1] == linear.shape[1]
|
assert mel.shape[1] == linear.shape[1]
|
||||||
linear = torch.FloatTensor(linear).contiguous()
|
linear = torch.FloatTensor(linear).contiguous()
|
||||||
else:
|
|
||||||
linear = None
|
|
||||||
|
|
||||||
# format waveforms
|
# format waveforms
|
||||||
wav_padded = None
|
wav_padded = None
|
||||||
|
@ -431,8 +435,7 @@ class TTSDataset(Dataset):
|
||||||
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
||||||
wav_padded.transpose_(1, 2)
|
wav_padded.transpose_(1, 2)
|
||||||
|
|
||||||
# compute f0
|
# format F0
|
||||||
# TODO: compare perf in collate_fn vs in load_data
|
|
||||||
if self.compute_f0:
|
if self.compute_f0:
|
||||||
pitch = prepare_data(batch["pitch"])
|
pitch = prepare_data(batch["pitch"])
|
||||||
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
||||||
|
@ -440,7 +443,8 @@ class TTSDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
pitch = None
|
pitch = None
|
||||||
|
|
||||||
# collate attention alignments
|
# format attention masks
|
||||||
|
attns = None
|
||||||
if batch["attn"][0] is not None:
|
if batch["attn"][0] is not None:
|
||||||
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
||||||
for idx, attn in enumerate(attns):
|
for idx, attn in enumerate(attns):
|
||||||
|
@ -451,12 +455,10 @@ class TTSDataset(Dataset):
|
||||||
attns[idx] = attn
|
attns[idx] = attn
|
||||||
attns = prepare_tensor(attns, self.outputs_per_step)
|
attns = prepare_tensor(attns, self.outputs_per_step)
|
||||||
attns = torch.FloatTensor(attns).unsqueeze(1)
|
attns = torch.FloatTensor(attns).unsqueeze(1)
|
||||||
else:
|
|
||||||
attns = None
|
|
||||||
# TODO: return dictionary
|
|
||||||
return {
|
return {
|
||||||
"text": text,
|
"token_id": text,
|
||||||
"text_lengths": text_lengths,
|
"token_id_lengths": token_ids_lengths,
|
||||||
"speaker_names": batch["speaker_name"],
|
"speaker_names": batch["speaker_name"],
|
||||||
"linear": linear,
|
"linear": linear,
|
||||||
"mel": mel,
|
"mel": mel,
|
||||||
|
@ -482,22 +484,179 @@ class TTSDataset(Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PitchExtractor:
|
class PhonemeDataset(Dataset):
|
||||||
"""Pitch Extractor for computing F0 from wav files.
|
"""Phoneme Dataset for converting input text to phonemes and then token IDs
|
||||||
|
|
||||||
|
At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
|
||||||
|
loading latency. If `cache_path` is already present, it skips the pre-computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items (List[List]): Dataset samples.
|
samples (Union[List[List], List[Dict]]):
|
||||||
verbose (bool): Whether to print the progress.
|
List of samples. Each sample is a list or a dict.
|
||||||
|
|
||||||
|
tokenizer (TTSTokenizer):
|
||||||
|
Tokenizer to convert input text to phonemes.
|
||||||
|
|
||||||
|
cache_path (str):
|
||||||
|
Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
|
||||||
|
|
||||||
|
precompute_num_workers (int):
|
||||||
|
Number of workers used for pre-computing the phonemes. Defaults to 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
items: List[Dict],
|
samples: Union[List[Dict], List[List]],
|
||||||
verbose=False,
|
tokenizer: "TTSTokenizer",
|
||||||
|
cache_path: str,
|
||||||
|
precompute_num_workers=0,
|
||||||
):
|
):
|
||||||
self.items = items
|
self.samples = samples
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.cache_path = cache_path
|
||||||
|
if cache_path is not None and not os.path.exists(cache_path):
|
||||||
|
os.makedirs(cache_path)
|
||||||
|
self.precompute(precompute_num_workers)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
text, wav_file, *_ = _parse_sample(self.samples[index])
|
||||||
|
ids = self.compute_or_load(wav_file, text)
|
||||||
|
ph_hat = self.tokenizer.ids_to_text(ids)
|
||||||
|
return {"text": text, "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def compute_or_load(self, wav_file, text):
|
||||||
|
"""Compute phonemes for the given text.
|
||||||
|
|
||||||
|
If the phonemes are already cached, load them from cache.
|
||||||
|
"""
|
||||||
|
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||||
|
file_ext = "_phoneme.npy"
|
||||||
|
cache_path = os.path.join(self.cache_path, file_name + file_ext)
|
||||||
|
try:
|
||||||
|
ids = np.load(cache_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
ids = self.tokenizer.text_to_ids(text)
|
||||||
|
np.save(cache_path, ids)
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def get_pad_id(self):
|
||||||
|
"""Get pad token ID for sequence padding"""
|
||||||
|
return self.tokenizer.pad_id
|
||||||
|
|
||||||
|
def precompute(self, num_workers=1):
|
||||||
|
"""Precompute phonemes for all samples.
|
||||||
|
|
||||||
|
We use pytorch dataloader because we are lazy.
|
||||||
|
"""
|
||||||
|
with tqdm.tqdm(total=len(self)) as pbar:
|
||||||
|
batch_size = num_workers if num_workers > 0 else 1
|
||||||
|
dataloder = torch.utils.data.DataLoader(
|
||||||
|
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
|
||||||
|
)
|
||||||
|
for _ in dataloder:
|
||||||
|
pbar.update(batch_size)
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
ids = [item["token_ids"] for item in batch]
|
||||||
|
ids_lens = [item["token_ids_len"] for item in batch]
|
||||||
|
texts = [item["text"] for item in batch]
|
||||||
|
texts_hat = [item["ph_hat"] for item in batch]
|
||||||
|
ids_lens_max = max(ids_lens)
|
||||||
|
ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
|
||||||
|
for i, ids_len in enumerate(ids_lens):
|
||||||
|
ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
|
||||||
|
return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}
|
||||||
|
|
||||||
|
def print_logs(self, level: int = 0) -> None:
|
||||||
|
indent = "\t" * level
|
||||||
|
print("\n")
|
||||||
|
print(f"{indent}> PhonemeDataset ")
|
||||||
|
print(f"{indent}| > Tokenizer:")
|
||||||
|
self.tokenizer.print_logs(level + 1)
|
||||||
|
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||||
|
|
||||||
|
|
||||||
|
class F0Dataset:
|
||||||
|
"""F0 Dataset for computing F0 from wav files in CPU
|
||||||
|
|
||||||
|
Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
|
||||||
|
also computes the mean and std of F0 values if `normalize_f0` is True.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (Union[List[List], List[Dict]]):
|
||||||
|
List of samples. Each sample is a list or a dict.
|
||||||
|
|
||||||
|
ap (AudioProcessor):
|
||||||
|
AudioProcessor to compute F0 from wav files.
|
||||||
|
|
||||||
|
cache_path (str):
|
||||||
|
Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
precompute_num_workers (int):
|
||||||
|
Number of workers used for pre-computing the F0 values. Defaults to 0.
|
||||||
|
|
||||||
|
normalize_f0 (bool):
|
||||||
|
Whether to normalize F0 values by mean and std. Defaults to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samples: Union[List[List], List[Dict]],
|
||||||
|
ap: "AudioProcessor",
|
||||||
|
verbose=False,
|
||||||
|
cache_path: str = None,
|
||||||
|
precompute_num_workers=0,
|
||||||
|
normalize_f0=True,
|
||||||
|
):
|
||||||
|
self.samples = samples
|
||||||
|
self.ap = ap
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
self.cache_path = cache_path
|
||||||
|
self.normalize_f0 = normalize_f0
|
||||||
|
self.pad_id = 0.0
|
||||||
self.mean = None
|
self.mean = None
|
||||||
self.std = None
|
self.std = None
|
||||||
|
if cache_path is not None and not os.path.exists(cache_path):
|
||||||
|
os.makedirs(cache_path)
|
||||||
|
self.precompute(precompute_num_workers)
|
||||||
|
if normalize_f0:
|
||||||
|
self.load_stats(cache_path)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
_, wav_file, *_ = _parse_sample(self.samples[idx])
|
||||||
|
f0 = self.compute_or_load(wav_file)
|
||||||
|
if self.normalize_f0:
|
||||||
|
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
|
||||||
|
f0 = self.normalize(f0)
|
||||||
|
return {"audio_file": wav_file, "f0": f0}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def precompute(self, num_workers=0):
|
||||||
|
with tqdm.tqdm(total=len(self)) as pbar:
|
||||||
|
batch_size = num_workers if num_workers > 0 else 1
|
||||||
|
dataloder = torch.utils.data.DataLoader(
|
||||||
|
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
|
||||||
|
)
|
||||||
|
computed_data = []
|
||||||
|
for batch in dataloder:
|
||||||
|
f0 = batch["f0"]
|
||||||
|
computed_data.append([f for f in f0])
|
||||||
|
pbar.update(batch_size)
|
||||||
|
|
||||||
|
if self.normalize_f0:
|
||||||
|
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
|
||||||
|
pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
|
||||||
|
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
||||||
|
np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
||||||
|
|
||||||
|
def get_pad_id(self):
|
||||||
|
return self.pad_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_pitch_file_path(wav_file, cache_path):
|
def create_pitch_file_path(wav_file, cache_path):
|
||||||
|
@ -519,69 +678,128 @@ class PitchExtractor:
|
||||||
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
def normalize_pitch(self, pitch):
|
def load_stats(self, cache_path):
|
||||||
|
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||||
|
stats = np.load(stats_path, allow_pickle=True).item()
|
||||||
|
self.mean = stats["mean"].astype(np.float32)
|
||||||
|
self.std = stats["std"].astype(np.float32)
|
||||||
|
|
||||||
|
def normalize(self, pitch):
|
||||||
zero_idxs = np.where(pitch == 0.0)[0]
|
zero_idxs = np.where(pitch == 0.0)[0]
|
||||||
pitch = pitch - self.mean
|
pitch = pitch - self.mean
|
||||||
pitch = pitch / self.std
|
pitch = pitch / self.std
|
||||||
pitch[zero_idxs] = 0.0
|
pitch[zero_idxs] = 0.0
|
||||||
return pitch
|
return pitch
|
||||||
|
|
||||||
def denormalize_pitch(self, pitch):
|
def denormalize(self, pitch):
|
||||||
zero_idxs = np.where(pitch == 0.0)[0]
|
zero_idxs = np.where(pitch == 0.0)[0]
|
||||||
pitch *= self.std
|
pitch *= self.std
|
||||||
pitch += self.mean
|
pitch += self.mean
|
||||||
pitch[zero_idxs] = 0.0
|
pitch[zero_idxs] = 0.0
|
||||||
return pitch
|
return pitch
|
||||||
|
|
||||||
@staticmethod
|
def compute_or_load(self, wav_file):
|
||||||
def load_or_compute_pitch(ap, wav_file, cache_path):
|
|
||||||
"""
|
"""
|
||||||
compute pitch and return a numpy array of pitch values
|
compute pitch and return a numpy array of pitch values
|
||||||
"""
|
"""
|
||||||
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
|
||||||
if not os.path.exists(pitch_file):
|
if not os.path.exists(pitch_file):
|
||||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
|
||||||
else:
|
else:
|
||||||
pitch = np.load(pitch_file)
|
pitch = np.load(pitch_file)
|
||||||
return pitch.astype(np.float32)
|
return pitch.astype(np.float32)
|
||||||
|
|
||||||
@staticmethod
|
def collate_fn(self, batch):
|
||||||
def _pitch_worker(args):
|
audio_file = [item["audio_file"] for item in batch]
|
||||||
item = args[0]
|
f0s = [item["f0"] for item in batch]
|
||||||
ap = args[1]
|
f0_lens = [len(item["f0"]) for item in batch]
|
||||||
cache_path = args[2]
|
f0_lens_max = max(f0_lens)
|
||||||
pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
|
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
|
||||||
if not os.path.exists(pitch_file):
|
for i, f0_len in enumerate(f0_lens):
|
||||||
pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
|
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
|
||||||
return pitch
|
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
|
||||||
return None
|
|
||||||
|
|
||||||
def compute_pitch(self, ap, cache_path, num_workers=0):
|
def print_logs(self, level: int = 0) -> None:
|
||||||
"""Compute the input sequences with multi-processing.
|
indent = "\t" * level
|
||||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
print("\n")
|
||||||
if not os.path.exists(cache_path):
|
print(f"{indent}> F0Dataset ")
|
||||||
os.makedirs(cache_path, exist_ok=True)
|
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
print(" | > Computing pitch features ...")
|
|
||||||
if num_workers == 0:
|
|
||||||
pitch_vecs = []
|
|
||||||
for _, item in enumerate(tqdm.tqdm(self.items)):
|
|
||||||
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
|
|
||||||
else:
|
|
||||||
with Pool(num_workers) as p:
|
|
||||||
pitch_vecs = list(
|
|
||||||
tqdm.tqdm(
|
|
||||||
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
|
|
||||||
total=len(self.items),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
|
||||||
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
|
||||||
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
|
||||||
|
|
||||||
def load_pitch_stats(self, cache_path):
|
if __name__ == "__main__":
|
||||||
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
from torch.utils.data import DataLoader
|
||||||
stats = np.load(stats_path, allow_pickle=True).item()
|
|
||||||
self.mean = stats["mean"].astype(np.float32)
|
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
|
||||||
self.std = stats["std"].astype(np.float32)
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.utils.text.characters import IPAPhonemes
|
||||||
|
from TTS.tts.utils.text.phonemizers import ESpeak
|
||||||
|
|
||||||
|
dataset_config = BaseDatasetConfig(
|
||||||
|
name="ljspeech",
|
||||||
|
meta_file_train="metadata.csv",
|
||||||
|
path="/Users/erengolge/Projects/TTS/recipes/ljspeech/LJSpeech-1.1",
|
||||||
|
)
|
||||||
|
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
||||||
|
samples = train_samples + eval_samples
|
||||||
|
|
||||||
|
phonemizer = ESpeak(language="en-us")
|
||||||
|
tokenizer = TTSTokenizer(use_phonemes=True, characters=IPAPhonemes(), phonemizer=phonemizer)
|
||||||
|
# ph_dataset = PhonemeDataset(samples, tokenizer, phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests")
|
||||||
|
# ph_dataset.precompute(num_workers=4)
|
||||||
|
|
||||||
|
# dataloader = DataLoader(ph_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=ph_dataset.collate_fn)
|
||||||
|
# for batch in dataloader:
|
||||||
|
# print(batch)
|
||||||
|
# break
|
||||||
|
|
||||||
|
audio_config = BaseAudioConfig(
|
||||||
|
sample_rate=22050,
|
||||||
|
win_length=1024,
|
||||||
|
hop_length=256,
|
||||||
|
num_mels=80,
|
||||||
|
preemphasis=0.0,
|
||||||
|
ref_level_db=20,
|
||||||
|
log_func="np.log",
|
||||||
|
do_trim_silence=True,
|
||||||
|
trim_db=45,
|
||||||
|
mel_fmin=0,
|
||||||
|
mel_fmax=8000,
|
||||||
|
spec_gain=1.0,
|
||||||
|
signal_norm=False,
|
||||||
|
do_amp_to_db_linear=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ap = AudioProcessor.init_from_config(audio_config)
|
||||||
|
|
||||||
|
# f0_dataset = F0Dataset(samples, ap, cache_path="/Users/erengolge/Projects/TTS/f0_tests", verbose=False, precompute_num_workers=4)
|
||||||
|
|
||||||
|
# dataloader = DataLoader(f0_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=f0_dataset.collate_fn)
|
||||||
|
# for batch in dataloader:
|
||||||
|
# print(batch)
|
||||||
|
# breakpoint()
|
||||||
|
# break
|
||||||
|
|
||||||
|
dataset = TTSDataset(
|
||||||
|
outputs_per_step=1,
|
||||||
|
compute_linear_spec=False,
|
||||||
|
meta_data=samples,
|
||||||
|
ap=ap,
|
||||||
|
return_wav=False,
|
||||||
|
batch_group_size=0,
|
||||||
|
min_seq_len=0,
|
||||||
|
max_seq_len=500,
|
||||||
|
use_noise_augment=False,
|
||||||
|
verbose=True,
|
||||||
|
speaker_id_mapping=None,
|
||||||
|
d_vector_mapping=None,
|
||||||
|
compute_f0=True,
|
||||||
|
f0_cache_path="/Users/erengolge/Projects/TTS/f0_tests",
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
phoneme_cache_path="/Users/erengolge/Projects/TTS/phonemes_tests",
|
||||||
|
precompute_num_workers=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
|
||||||
|
for batch in dataloader:
|
||||||
|
print(batch)
|
||||||
|
break
|
||||||
|
|
Loading…
Reference in New Issue