mirror of https://github.com/coqui-ai/TTS.git
Merge dataset
This commit is contained in:
parent
c80cf67d3d
commit
5c89803968
|
@ -56,10 +56,6 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
meta_data (list): List of dataset instances.
|
meta_data (list): List of dataset instances.
|
||||||
|
|
||||||
compute_f0 (bool): compute f0 if True. Defaults to False.
|
|
||||||
|
|
||||||
f0_cache_path (str): Path to store f0 cache. Defaults to None.
|
|
||||||
|
|
||||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||||
|
|
||||||
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
||||||
|
@ -109,8 +105,6 @@ class TTSDataset(Dataset):
|
||||||
self.cleaners = text_cleaner
|
self.cleaners = text_cleaner
|
||||||
self.compute_linear_spec = compute_linear_spec
|
self.compute_linear_spec = compute_linear_spec
|
||||||
self.return_wav = return_wav
|
self.return_wav = return_wav
|
||||||
self.compute_f0 = compute_f0
|
|
||||||
self.f0_cache_path = f0_cache_path
|
|
||||||
self.min_seq_len = min_seq_len
|
self.min_seq_len = min_seq_len
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
@ -339,7 +333,6 @@ class TTSDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
lengths = np.array([len(ins[0]) for ins in self.items])
|
lengths = np.array([len(ins[0]) for ins in self.items])
|
||||||
|
|
||||||
# sort items based on the sequence length in ascending order
|
|
||||||
idxs = np.argsort(lengths)
|
idxs = np.argsort(lengths)
|
||||||
new_items = []
|
new_items = []
|
||||||
ignored = []
|
ignored = []
|
||||||
|
@ -349,10 +342,7 @@ class TTSDataset(Dataset):
|
||||||
ignored.append(idx)
|
ignored.append(idx)
|
||||||
else:
|
else:
|
||||||
new_items.append(self.items[idx])
|
new_items.append(self.items[idx])
|
||||||
|
|
||||||
# shuffle batch groups
|
# shuffle batch groups
|
||||||
# create batches with similar length items
|
|
||||||
# the larger the `batch_group_size`, the higher the length variety in a batch.
|
|
||||||
if self.batch_group_size > 0:
|
if self.batch_group_size > 0:
|
||||||
for i in range(len(new_items) // self.batch_group_size):
|
for i in range(len(new_items) // self.batch_group_size):
|
||||||
offset = i * self.batch_group_size
|
offset = i * self.batch_group_size
|
||||||
|
@ -360,14 +350,8 @@ class TTSDataset(Dataset):
|
||||||
temp_items = new_items[offset:end_offset]
|
temp_items = new_items[offset:end_offset]
|
||||||
random.shuffle(temp_items)
|
random.shuffle(temp_items)
|
||||||
new_items[offset:end_offset] = temp_items
|
new_items[offset:end_offset] = temp_items
|
||||||
|
|
||||||
if len(new_items) == 0:
|
|
||||||
raise RuntimeError(" [!] No items left after filtering.")
|
|
||||||
|
|
||||||
# update items to the new sorted items
|
|
||||||
self.items = new_items
|
self.items = new_items
|
||||||
|
|
||||||
# logging
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
||||||
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
||||||
|
@ -554,110 +538,3 @@ class TTSDataset(Dataset):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PitchExtractor:
|
|
||||||
"""Pitch Extractor for computing F0 from wav files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items (List[List]): Dataset samples.
|
|
||||||
verbose (bool): Whether to print the progress.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
items: List[List],
|
|
||||||
verbose=False,
|
|
||||||
):
|
|
||||||
self.items = items
|
|
||||||
self.verbose = verbose
|
|
||||||
self.mean = None
|
|
||||||
self.std = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_pitch_file_path(wav_file, cache_path):
|
|
||||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
|
||||||
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
|
|
||||||
return pitch_file
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
|
|
||||||
wav = ap.load_wav(wav_file)
|
|
||||||
pitch = ap.compute_f0(wav)
|
|
||||||
if pitch_file:
|
|
||||||
np.save(pitch_file, pitch)
|
|
||||||
return pitch
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def compute_pitch_stats(pitch_vecs):
|
|
||||||
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
|
|
||||||
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
|
||||||
return mean, std
|
|
||||||
|
|
||||||
def normalize_pitch(self, pitch):
|
|
||||||
zero_idxs = np.where(pitch == 0.0)[0]
|
|
||||||
pitch = pitch - self.mean
|
|
||||||
pitch = pitch / self.std
|
|
||||||
pitch[zero_idxs] = 0.0
|
|
||||||
return pitch
|
|
||||||
|
|
||||||
def denormalize_pitch(self, pitch):
|
|
||||||
zero_idxs = np.where(pitch == 0.0)[0]
|
|
||||||
pitch *= self.std
|
|
||||||
pitch += self.mean
|
|
||||||
pitch[zero_idxs] = 0.0
|
|
||||||
return pitch
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_or_compute_pitch(ap, wav_file, cache_path):
|
|
||||||
"""
|
|
||||||
compute pitch and return a numpy array of pitch values
|
|
||||||
"""
|
|
||||||
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
|
||||||
if not os.path.exists(pitch_file):
|
|
||||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
|
||||||
else:
|
|
||||||
pitch = np.load(pitch_file)
|
|
||||||
return pitch.astype(np.float32)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _pitch_worker(args):
|
|
||||||
item = args[0]
|
|
||||||
ap = args[1]
|
|
||||||
cache_path = args[2]
|
|
||||||
_, wav_file, *_ = item
|
|
||||||
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
|
||||||
if not os.path.exists(pitch_file):
|
|
||||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
|
||||||
return pitch
|
|
||||||
return None
|
|
||||||
|
|
||||||
def compute_pitch(self, ap, cache_path, num_workers=0):
|
|
||||||
"""Compute the input sequences with multi-processing.
|
|
||||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
|
||||||
if not os.path.exists(cache_path):
|
|
||||||
os.makedirs(cache_path, exist_ok=True)
|
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
print(" | > Computing pitch features ...")
|
|
||||||
if num_workers == 0:
|
|
||||||
pitch_vecs = []
|
|
||||||
for _, item in enumerate(tqdm.tqdm(self.items)):
|
|
||||||
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
|
|
||||||
else:
|
|
||||||
with Pool(num_workers) as p:
|
|
||||||
pitch_vecs = list(
|
|
||||||
tqdm.tqdm(
|
|
||||||
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
|
|
||||||
total=len(self.items),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
|
||||||
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
|
||||||
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
|
||||||
|
|
||||||
def load_pitch_stats(self, cache_path):
|
|
||||||
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
|
||||||
stats = np.load(stats_path, allow_pickle=True).item()
|
|
||||||
self.mean = stats["mean"].astype(np.float32)
|
|
||||||
self.std = stats["std"].astype(np.float32)
|
|
||||||
|
|
Loading…
Reference in New Issue