Cache pitch features

Cache the pitch (F0) features at the beginning of `BaseTTS` training so they are computed once and read from disk afterwards.
Eren Gölge 2021-07-12 18:23:19 +02:00
parent 7590c7db7a
commit d085642ac1
3 changed files with 140 additions and 23 deletions
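
The new `compute_f0` / `f0_cache_path` options are wired from the config through `BaseTTS.get_data_loader` into `TTSDataset`. A minimal sketch of enabling them directly on the dataset; the import path follows this diff, while `ap`, the metadata item, and the cache directory are placeholder assumptions:

from TTS.utils.audio import AudioProcessor
from TTS.tts.datasets.TTSDataset import TTSDataset  # module path assumed

ap = AudioProcessor(sample_rate=22050)  # hypothetical minimal audio config
meta_data = [["hello world", "wavs/sample.wav", "speaker0"]]  # placeholder item

dataset = TTSDataset(
    outputs_per_step=1,
    text_cleaner="basic_cleaners",
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,
    compute_f0=True,           # compute per-sample pitch ...
    f0_cache_path="cache/f0",  # ... and cache it here as .npy files
)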

View File

@@ -9,7 +9,7 @@ import torch
import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.utils.audio import AudioProcessor
@@ -23,6 +23,7 @@ class TTSDataset(Dataset):
ap: AudioProcessor,
meta_data: List[List],
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
@@ -41,8 +42,7 @@
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can either override or create a new class as the dataset is
initialized by the model.
If you need something different, you can inherit and override.
Args:
outputs_per_step (int): Number of time frames predicted per step.
@@ -57,6 +57,8 @@ class TTSDataset(Dataset):
compute_f0 (bool): compute f0 if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
@@ -81,8 +83,8 @@
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
the coming iterations. Defaults to None.
phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
separate file. Defaults to None.
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
@@ -107,6 +109,7 @@ class TTSDataset(Dataset):
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.f0_cache_path = f0_cache_path
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
@@ -123,6 +126,7 @@ class TTSDataset(Dataset):
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
self.pitch_computed = False
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if self.verbose:
@@ -240,10 +244,14 @@ class TTSDataset(Dataset):
# TODO: find a better fix
return self.load_data(self.rescue_item_idx)
if self.compute_f0:
pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
sample = {
"raw_text": raw_text,
"text": text,
"wav": wav,
"pitch": pitch,
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
@@ -260,8 +268,8 @@ class TTSDataset(Dataset):
return phonemes
def compute_input_seq(self, num_workers=0):
"""compute input sequences separately. Call it before
passing dataset to data loader."""
"""Compute the input sequences with multi-processing.
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
@@ -306,6 +314,64 @@ class TTSDataset(Dataset):
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
@staticmethod
def create_pitch_file_path(wav_file, cache_path):
"""Return the cache file path of the pitch values computed for `wav_file`."""
file_name = os.path.splitext(os.path.basename(wav_file))[0]
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
return pitch_file
@staticmethod
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
"""Compute the pitch values of `wav_file` with `ap.compute_f0()` and save them to `pitch_file` if given."""
wav = ap.load_wav(wav_file)
pitch = ap.compute_f0(wav)
if pitch_file:
np.save(pitch_file, pitch)
return pitch
@staticmethod
def _load_or_compute_pitch(ap, wav_file, cache_path):
"""Load the pitch values from the cache file if it exists, else compute and cache them.
Return a numpy array of pitch values."""
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
if not os.path.exists(pitch_file):
pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch
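
Together, the helpers above implement a compute-once, read-many cache keyed by the wav file name. A self-contained sketch of the same round-trip with stand-in functions (not the class methods themselves):

import os
import numpy as np

def pitch_cache_path(wav_file, cache_path):
    # mirrors create_pitch_file_path(): "<cache>/<wav_basename>_pitch.npy"
    name = os.path.splitext(os.path.basename(wav_file))[0]
    return os.path.join(cache_path, name + "_pitch.npy")

def load_or_compute(wav_file, cache_path, compute_fn):
    path = pitch_cache_path(wav_file, cache_path)
    if os.path.exists(path):
        return np.load(path)  # cache hit: reuse the values from an earlier run
    pitch = compute_fn(wav_file)  # cache miss: compute once ...
    np.save(path, pitch)  # ... and persist for later epochs and processes
    return pitch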
@staticmethod
def _pitch_worker(args):
"""Compute and cache the pitch values of a single dataset item. Return True if the values were newly computed."""
item, ap, cache_path = args
_, wav_file, *_ = item
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
if not os.path.exists(pitch_file):
TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
return True
return False
def compute_pitch(self, cache_path, num_workers=0):
"""Compute the pitch values of the dataset samples with multi-processing and cache them to `cache_path`.
Call it before passing the dataset to the data loader to avoid computing pitch at training time."""
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
if self.verbose:
print(" | > Computing pitch features ...")
if num_workers == 0:
for item in tqdm.tqdm(self.items):
self._pitch_worker([item, self.ap, cache_path])
else:
with Pool(num_workers) as p:
_ = list(
tqdm.tqdm(
p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
total=len(self.items),
)
)
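
Typical use is a one-off warm-up call before building the data loader; the cache directory and worker count below are placeholders:

# Warm the pitch cache once; later epochs and other processes only read
# the cached .npy files instead of recomputing F0.
dataset.compute_pitch("cache/f0", num_workers=4)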
def sort_and_filter_items(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
@@ -367,7 +433,7 @@ class TTSDataset(Dataset):
r"""
Perform preprocessing and create a final data batch:
1. Sort batch instances by text-length
2. Convert Audio signal to Spectrograms.
2. Convert Audio signal to features.
3. Pad sequences w.r.t. r.
4. Load everything into Torch tensors.
"""
@@ -466,11 +532,12 @@ class TTSDataset(Dataset):
# TODO: compare perf in collate_fn vs in load_data
pitch = None
if self.compute_f0:
pitch = [self.ap.compute_f0(w).astype("float32") for w in wav]
pitch = prepare_tensor(pitch, self.outputs_per_step)
pitch = pitch.transpose(0, 2, 1)
assert mel.shape[1] == pitch.shape[1]
pitch = torch.FloatTensor(pitch).contiguous()
pitch = [b["pitch"] for b in batch]
pitch = prepare_data(pitch)
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
else:
pitch = None
# collate attention alignments
if batch[0]["attn"] is not None:
@@ -478,6 +545,7 @@ class TTSDataset(Dataset):
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = text.shape[1] - attn.shape[0]
assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
attn = np.pad(attn, [[0, pad1], [0, pad2]])
attns[idx] = attn
attns = prepare_tensor(attns, self.outputs_per_step)
@@ -499,6 +567,7 @@ class TTSDataset(Dataset):
attns,
wav_padded,
raw_text,
pitch,
)
raise TypeError(

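For the pitch collation above, assuming `prepare_data()` zero-pads each 1-D pitch array to the longest one in the batch, the `[:, None, :]` indexing yields a `B x 1 x T` tensor. A self-contained sketch of the resulting shape:

import numpy as np
import torch

# two variable-length pitch tracks, as loaded from the per-sample cache
batch_pitch = [np.array([110.0, 112.0, 114.0]), np.array([98.0, 99.0])]
max_len = max(len(p) for p in batch_pitch)
padded = np.stack([np.pad(p, (0, max_len - len(p))) for p in batch_pitch])
pitch = torch.FloatTensor(padded)[:, None, :].contiguous()
print(pitch.shape)  # torch.Size([2, 1, 3]) -> B x 1 x T
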
View File

@@ -116,6 +116,7 @@ class BaseTTS(BaseModel):
speaker_ids = batch[9]
attn_mask = batch[10]
waveform = batch[11]
pitch = batch[13]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
"max_spec_length": float(max_spec_length),
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
}
def get_data_loader(
@@ -200,6 +202,7 @@ class BaseTTS(BaseModel):
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
ap=ap,
characters=config.characters,
@@ -246,6 +249,14 @@ class BaseTTS(BaseModel):
# sort input sequences from short to long
dataset.sort_and_filter_items(config.get("sort_by_audio_len", False))
# compute pitch frames and write them to files
if config.get("compute_f0", False) and not os.path.exists(config.f0_cache_path) and rank in [None, 0]:
dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
# halt the DDP processes until the main process finishes computing the F0 cache
if num_gpus > 1:
dist.barrier()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
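
The cache warm-up above follows the usual rank-0-writes, everyone-waits DDP pattern; a generic sketch of the same idea with placeholder names:

import torch.distributed as dist

def warm_cache(rank, num_gpus, build_cache):
    # only the main process writes the cache files ...
    if rank in (None, 0):
        build_cache()
    # ... and every replica blocks here until it is done, so all
    # processes read a complete cache afterwards
    if num_gpus > 1:
        dist.barrier()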

View File

@@ -9,8 +9,7 @@ import torch
from torch import nn
from TTS.tts.utils.data import StandardScaler
# import pyworld as pw
from TTS.utils.yin import compute_yin
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
@@ -648,15 +647,53 @@ class AudioProcessor(object):
# frame_period=1000 * self.hop_length / self.sample_rate,
# )
# f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
# f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size)
f0, _, _ = librosa.pyin(
x.astype(np.double),
fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
fmax=self.mel_fmax,
frame_length=self.win_length,
sr=self.sample_rate,
fill_na=0.0,
f0, _, _, _ = compute_yin(
x,
self.sample_rate,
self.win_length,
self.hop_length,
65 if self.mel_fmin == 0 else self.mel_fmin,
self.mel_fmax,
)
# pad both ends with zeros so the f0 frames align with the centered spectrogram frames
pad = int((self.win_length / self.hop_length) / 2)
f0 = [0.0] * pad + f0 + [0.0] * pad
f0 = np.array(f0, dtype=np.float32)
return f0
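
With this padding, the number of F0 frames should match the number of spectrogram frames, which the collate-time assert relies on. A quick consistency check one might run, assuming an initialized `AudioProcessor` instance `ap`:

import numpy as np

x = np.random.uniform(-1, 1, size=22050).astype(np.float32)  # 1s of noise
f0 = ap.compute_f0(x)  # frames of pitch values
mel = ap.melspectrogram(x)  # num_mels x frames
assert f0.shape[0] == mel.shape[1], (f0.shape, mel.shape)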
### Audio Processing ###