mirror of https://github.com/coqui-ai/TTS.git
Compute mean and std pitch
This commit is contained in:
parent
8fffd4e813
commit
e802b24ad0
|
@ -127,6 +127,7 @@ class TTSDataset(Dataset):
|
||||||
self.input_seq_computed = False
|
self.input_seq_computed = False
|
||||||
self.rescue_item_idx = 1
|
self.rescue_item_idx = 1
|
||||||
self.pitch_computed = False
|
self.pitch_computed = False
|
||||||
|
|
||||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -247,6 +248,7 @@ class TTSDataset(Dataset):
|
||||||
pitch = None
|
pitch = None
|
||||||
if self.compute_f0:
|
if self.compute_f0:
|
||||||
pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||||
|
pitch = self.normalize_pitch(pitch)
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
"raw_text": raw_text,
|
"raw_text": raw_text,
|
||||||
|
@ -315,6 +317,11 @@ class TTSDataset(Dataset):
|
||||||
for idx, p in enumerate(phonemes):
|
for idx, p in enumerate(phonemes):
|
||||||
self.items[idx][0] = p
|
self.items[idx][0] = p
|
||||||
|
|
||||||
|
################
|
||||||
|
# Pitch Methods
|
||||||
|
###############
|
||||||
|
# TODO: Refactor Pitch methods into a separate class
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_pitch_file_path(wav_file, cache_path):
|
def create_pitch_file_path(wav_file, cache_path):
|
||||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||||
|
@ -329,6 +336,19 @@ class TTSDataset(Dataset):
|
||||||
np.save(pitch_file, pitch)
|
np.save(pitch_file, pitch)
|
||||||
return pitch
|
return pitch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_pitch_stats(pitch_vecs):
|
||||||
|
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
|
||||||
|
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
def normalize_pitch(self, pitch):
|
||||||
|
zero_idxs = np.where(pitch == 0.0)[0]
|
||||||
|
pitch -= self.mean
|
||||||
|
pitch /= self.std
|
||||||
|
pitch[zero_idxs] = 0.0
|
||||||
|
return pitch
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_or_compute_pitch(ap, wav_file, cache_path):
|
def _load_or_compute_pitch(ap, wav_file, cache_path):
|
||||||
"""
|
"""
|
||||||
|
@ -349,9 +369,9 @@ class TTSDataset(Dataset):
|
||||||
_, wav_file, *_ = item
|
_, wav_file, *_ = item
|
||||||
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
|
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
|
||||||
if not os.path.exists(pitch_file):
|
if not os.path.exists(pitch_file):
|
||||||
TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
|
pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||||
return True
|
return pitch
|
||||||
return False
|
return None
|
||||||
|
|
||||||
def compute_pitch(self, cache_path, num_workers=0):
|
def compute_pitch(self, cache_path, num_workers=0):
|
||||||
"""Compute the input sequences with multi-processing.
|
"""Compute the input sequences with multi-processing.
|
||||||
|
@ -362,16 +382,30 @@ class TTSDataset(Dataset):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(" | > Computing pitch features ...")
|
print(" | > Computing pitch features ...")
|
||||||
if num_workers == 0:
|
if num_workers == 0:
|
||||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
pitch_vecs = []
|
||||||
self._pitch_worker([item, self.ap, cache_path])
|
for _, item in enumerate(tqdm.tqdm(self.items)):
|
||||||
|
pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
|
||||||
else:
|
else:
|
||||||
with Pool(num_workers) as p:
|
with Pool(num_workers) as p:
|
||||||
_ = list(
|
pitch_vecs = list(
|
||||||
tqdm.tqdm(
|
tqdm.tqdm(
|
||||||
p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
|
p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
|
||||||
total=len(self.items),
|
total=len(self.items),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
||||||
|
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
||||||
|
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
||||||
|
|
||||||
|
def load_pitch_stats(self, cache_path):
|
||||||
|
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||||
|
stats = np.load(stats_path, allow_pickle=True).item()
|
||||||
|
self.mean = stats["mean"]
|
||||||
|
self.std = stats["std"]
|
||||||
|
|
||||||
|
###################
|
||||||
|
# End Pitch Methods
|
||||||
|
###################
|
||||||
|
|
||||||
def sort_and_filter_items(self, by_audio_len=False):
|
def sort_and_filter_items(self, by_audio_len=False):
|
||||||
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||||
|
|
|
@ -250,8 +250,10 @@ class BaseTTS(BaseModel):
|
||||||
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
||||||
|
|
||||||
# compute pitch frames and write to files.
|
# compute pitch frames and write to files.
|
||||||
if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]:
|
if config.compute_f0 and rank in [None, 0]:
|
||||||
|
if not os.path.exists(config.f0_cache_path):
|
||||||
dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
||||||
|
dataset.load_pitch_stats(config.get("f0_cache_path", None))
|
||||||
|
|
||||||
# halt DDP processes for the main process to finish computing the F0 cache
|
# halt DDP processes for the main process to finish computing the F0 cache
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
|
|
Loading…
Reference in New Issue