Cache pitch features

Cache the pitch (F0) features at the beginning of `BaseTTS` training, so they are computed once per wav file and loaded from disk afterwards.
Eren Gölge 2021-07-12 18:23:19 +02:00
parent 7590c7db7a
commit d085642ac1
3 changed files with 140 additions and 23 deletions
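
In short: `TTSDataset` gains an `f0_cache_path` argument, pitch is computed once per wav file and written to a `.npy` file, and `BaseTTS.get_data_loader` fills the whole cache on the main process before training starts. A minimal sketch of how the feature is switched on, assuming a config object with the two fields the diff reads via `config.get(...)`:

    # hedged sketch; the concrete config class is not part of this commit
    config.compute_f0 = True            # make TTSDataset attach a pitch contour to every sample
    config.f0_cache_path = "f0_cache/"  # directory that will hold the *_pitch.npy files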

TTS/tts/datasets/TTSDataset.py

@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset
-from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
@@ -23,6 +23,7 @@ class TTSDataset(Dataset):
         ap: AudioProcessor,
         meta_data: List[List],
         compute_f0: bool = False,
+        f0_cache_path: str = None,
         characters: Dict = None,
         custom_symbols: List = None,
         add_blank: bool = False,
@@ -41,8 +42,7 @@
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
-        If you need something different, you can either override or create a new class as the dataset is
-        initialized by the model.
+        If you need something different, you can inherit and override.

         Args:
             outputs_per_step (int): Number of time frames predicted per step.
@@ -57,6 +57,8 @@ class TTSDataset(Dataset):
             compute_f0 (bool): compute f0 if True. Defaults to False.

+            f0_cache_path (str): Path to store f0 cache. Defaults to None.
+
             characters (dict): `dict` of custom text characters used for converting texts to sequences.

             custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
@@ -81,8 +83,8 @@ class TTSDataset(Dataset):
             use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.

-            phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
-                the coming iterations. Defaults to None.
+            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
+                separate file. Defaults to None.

             phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
@@ -107,6 +109,7 @@ class TTSDataset(Dataset):
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
+        self.f0_cache_path = f0_cache_path
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
@@ -123,6 +126,7 @@ class TTSDataset(Dataset):
         self.verbose = verbose
         self.input_seq_computed = False
         self.rescue_item_idx = 1
+        self.pitch_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -240,10 +244,14 @@ class TTSDataset(Dataset):
             # TODO: find a better fix
             return self.load_data(self.rescue_item_idx)

+        pitch = None
+        if self.compute_f0:
+            pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
         sample = {
             "raw_text": raw_text,
             "text": text,
             "wav": wav,
+            "pitch": pitch,
             "attn": attn,
             "item_idx": self.items[idx][1],
             "speaker_name": speaker_name,
@@ -260,8 +268,8 @@ class TTSDataset(Dataset):
         return phonemes

     def compute_input_seq(self, num_workers=0):
-        """compute input sequences separately. Call it before
-        passing dataset to data loader."""
+        """Compute the input sequences with multi-processing.
+        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not self.use_phonemes:
             if self.verbose:
                 print(" | > Computing input sequences ...")
@@ -306,6 +314,64 @@ class TTSDataset(Dataset):
                 for idx, p in enumerate(phonemes):
                     self.items[idx][0] = p

+    @staticmethod
+    def create_pitch_file_path(wav_file, cache_path):
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
+        return pitch_file
+
+    @staticmethod
+    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
+        wav = ap.load_wav(wav_file)
+        pitch = ap.compute_f0(wav)
+        if pitch_file:
+            np.save(pitch_file, pitch)
+        return pitch
+
+    @staticmethod
+    def _load_or_compute_pitch(ap, wav_file, cache_path):
+        """Compute pitch and return a numpy array of pitch values."""
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch
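
Together, the three helpers above implement a compute-once, load-thereafter cache keyed on the wav file's basename. A sketch of the resulting convention (paths are illustrative):

    wav_file = "/data/LJSpeech/wavs/LJ001-0001.wav"
    # resolves to "f0_cache/LJ001-0001_pitch.npy"
    pitch_file = TTSDataset.create_pitch_file_path(wav_file, "f0_cache")
    # first call computes and saves; subsequent calls only np.load the file
    pitch = TTSDataset._load_or_compute_pitch(ap, wav_file, "f0_cache")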
+
+    @staticmethod
+    def _pitch_worker(args):
+        item = args[0]
+        ap = args[1]
+        cache_path = args[2]
+        _, wav_file, *_ = item
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+            return True
+        return False
+
+    def compute_pitch(self, cache_path, num_workers=0):
+        """Compute the pitch features with multi-processing.
+        Call it before passing the dataset to the data loader to cache the pitch features for faster data loading."""
+        if not os.path.exists(cache_path):
+            os.makedirs(cache_path, exist_ok=True)
+        if self.verbose:
+            print(" | > Computing pitch features ...")
+        if num_workers == 0:
+            for idx, item in enumerate(tqdm.tqdm(self.items)):
+                self._pitch_worker([item, self.ap, cache_path])
+        else:
+            with Pool(num_workers) as p:
+                _ = list(
+                    tqdm.tqdm(
+                        p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        total=len(self.items),
+                    )
+                )
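
Usage mirrors `compute_input_seq`: call it once before wrapping the dataset in a data loader; files that already exist are skipped by `_pitch_worker`, so repeated runs are cheap. For example (worker count illustrative):

    dataset.compute_pitch("f0_cache/", num_workers=8)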

     def sort_and_filter_items(self, by_audio_len=False):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length
         range.
@@ -367,7 +433,7 @@ class TTSDataset(Dataset):
         r"""
         Perform preprocessing and create a final data batch:
         1. Sort batch instances by text-length
-        2. Convert Audio signal to Spectrograms.
+        2. Convert Audio signal to features.
         3. PAD sequences wrt r.
         4. Load to Torch.
         """
@@ -466,11 +532,12 @@ class TTSDataset(Dataset):
             # TODO: compare perf in collate_fn vs in load_data
             pitch = None
             if self.compute_f0:
-                pitch = [self.ap.compute_f0(w).astype("float32") for w in wav]
-                pitch = prepare_tensor(pitch, self.outputs_per_step)
-                pitch = pitch.transpose(0, 2, 1)
-                assert mel.shape[1] == pitch.shape[1]
-                pitch = torch.FloatTensor(pitch).contiguous()
+                pitch = [b["pitch"] for b in batch]
+                pitch = prepare_data(pitch)
+                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
+                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
+            else:
+                pitch = None

             # collate attention alignments
             if batch[0]["attn"] is not None:
@@ -478,6 +545,7 @@ class TTSDataset(Dataset):
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
+                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                     attn = np.pad(attn, [[0, pad1], [0, pad2]])
                     attns[idx] = attn
                 attns = prepare_tensor(attns, self.outputs_per_step)
@@ -499,6 +567,7 @@ class TTSDataset(Dataset):
                 attns,
                 wav_padded,
                 raw_text,
+                pitch,
             )

         raise TypeError(

TTS/tts/models/base_tts.py

@@ -116,6 +116,7 @@ class BaseTTS(BaseModel):
         speaker_ids = batch[9]
         attn_mask = batch[10]
         waveform = batch[11]
+        pitch = batch[13]

         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
             "waveform": waveform,
+            "pitch": pitch,
         }

     def get_data_loader(
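
`format_batch` thus exposes the collated pitch to every model under the `"pitch"` key. A hedged sketch of a model-side consumer; the method body is illustrative, not part of this commit:

    def train_step(self, batch: dict, criterion):
        pitch = batch["pitch"]  # B x 1 x T tensor, or None when compute_f0 is disabled
        waveform = batch["waveform"]
        ...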
@@ -200,6 +202,7 @@ class BaseTTS(BaseModel):
                 text_cleaner=config.text_cleaner,
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
+                f0_cache_path=config.get("f0_cache_path", None),
                 meta_data=data_items,
                 ap=ap,
                 characters=config.characters,
@@ -246,6 +249,14 @@ class BaseTTS(BaseModel):
             # sort input sequences from short to long
             dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))

+            # compute pitch frames and write to files
+            if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]:
+                dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
+
+            # halt DDP processes for the main process to finish computing the F0 cache
+            if num_gpus > 1:
+                dist.barrier()
+
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
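
This is the usual rank-0-writes, everyone-waits idiom for building a shared on-disk cache under DDP, distilled below (assumes `torch.distributed` is initialized whenever `num_gpus > 1`):

    import torch.distributed as dist

    def build_cache_once(rank, num_gpus, cache_exists, compute_fn):
        if not cache_exists and rank in (None, 0):
            compute_fn()       # only the main process writes the cache files
        if num_gpus > 1:
            dist.barrier()     # other ranks block here until rank 0 is done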

TTS/utils/audio.py

@@ -9,8 +9,7 @@ import torch
 from torch import nn

 from TTS.tts.utils.data import StandardScaler
-# import pyworld as pw
+from TTS.utils.yin import compute_yin


 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
@@ -648,15 +647,53 @@ class AudioProcessor(object):
         #     frame_period=1000 * self.hop_length / self.sample_rate,
         # )
         # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
-        # f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size)
-        f0, _, _ = librosa.pyin(
-            x.astype(np.double),
-            fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
-            fmax=self.mel_fmax,
-            frame_length=self.win_length,
-            sr=self.sample_rate,
-            fill_na=0.0,
-        )
+        f0, _, _, _ = compute_yin(
+            x,
+            self.sample_rate,
+            self.win_length,
+            self.hop_length,
+            65 if self.mel_fmin == 0 else self.mel_fmin,
+            self.mel_fmax,
+        )
+        # import pyworld as pw
+        # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
+        #                frame_period=self.hop_length / self.sample_rate * 1000)
+        pad = int((self.win_length / self.hop_length) / 2)
+        f0 = [0.0] * pad + f0 + [0.0] * pad
+        f0 = np.array(f0, dtype=np.float32)
+        # f01, _, _ = librosa.pyin(
+        #     x,
+        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+        #     fmax=self.mel_fmax,
+        #     frame_length=self.win_length,
+        #     sr=self.sample_rate,
+        #     fill_na=0.0,
+        # )
+        # f02 = librosa.yin(
+        #     x,
+        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+        #     fmax=self.mel_fmax,
+        #     frame_length=self.win_length,
+        #     sr=self.sample_rate,
+        # )
+        # spec = self.melspectrogram(x)
+        # from matplotlib import pyplot as plt
+        # plt.figure()
+        # plt.plot(f0, linewidth=2.5, color='red')
+        # plt.plot(f01, linewidth=2.5, linestyle='-.')
+        # plt.plot(f02, linewidth=2.5)
+        # plt.xlabel('time')
+        # plt.ylabel('F0')
+        # plt.savefig('save_img.png')
+        # # plt.figure()
+        # plt.imshow(spec, aspect="auto", origin="lower")
+        # plt.savefig('save_img2.png')
+        # breakpoint()

         return f0

     ### Audio Processing ###
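
Note the padding arithmetic after `compute_yin`: yin emits one value per full analysis window, so the contour is short by roughly half a window at each end, and the prepended/appended zeros restore alignment with the spectrogram frame count that `collate_fn` asserts on. A worked example with common settings (values illustrative):

    win_length, hop_length = 1024, 256
    pad = int((win_length / hop_length) / 2)  # = 2 zero frames on each side
    assert pad == 2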