mirror of https://github.com/coqui-ai/TTS.git
Cache pitch features
Cache the features at the beginning of `BaseTTS` training.
parent 7590c7db7a
commit d085642ac1
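In short: instead of recomputing f0 inside `collate_fn` on every batch, the dataset now computes pitch per utterance once, caches it as `<wav_name>_pitch.npy` under `f0_cache_path`, and `load_data` returns the cached array with each sample. A minimal usage sketch (argument values are illustrative only, and the real constructor takes more options than shown):

from TTS.utils.audio import AudioProcessor
from TTS.tts.datasets.TTSDataset import TTSDataset  # module path assumed

ap = AudioProcessor(sample_rate=22050)  # illustrative config
dataset = TTSDataset(
    outputs_per_step=1,
    text_cleaner="english_cleaners",
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,  # [[text, wav_path, speaker_name], ...]
    compute_f0=True,
    f0_cache_path="/tmp/f0_cache",
)
# One-time pass: writes one <wav_name>_pitch.npy per utterance.
dataset.compute_pitch("/tmp/f0_cache", num_workers=4)
# From here on, __getitem__ loads pitch from disk instead of recomputing it.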
@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset
 
-from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
 
@@ -23,6 +23,7 @@ class TTSDataset(Dataset):
         ap: AudioProcessor,
         meta_data: List[List],
         compute_f0: bool = False,
+        f0_cache_path: str = None,
         characters: Dict = None,
         custom_symbols: List = None,
         add_blank: bool = False,
@@ -41,8 +42,7 @@ class TTSDataset(Dataset):
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
 
-        If you need something different, you can either override or create a new class as the dataset is
-        initialized by the model.
+        If you need something different, you can inherit and override.
 
         Args:
             outputs_per_step (int): Number of time frames predicted per step.
@@ -57,6 +57,8 @@ class TTSDataset(Dataset):
 
             compute_f0 (bool): compute f0 if True. Defaults to False.
 
+            f0_cache_path (str): Path to store f0 cache. Defaults to None.
+
             characters (dict): `dict` of custom text characters used for converting texts to sequences.
 
             custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
@@ -81,8 +83,8 @@ class TTSDataset(Dataset):
 
             use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
 
-            phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
-                the coming iterations. Defaults to None.
+            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
+                separate file. Defaults to None.
 
             phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
@@ -107,6 +109,7 @@ class TTSDataset(Dataset):
         self.compute_linear_spec = compute_linear_spec
         self.return_wav = return_wav
         self.compute_f0 = compute_f0
+        self.f0_cache_path = f0_cache_path
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
@@ -123,6 +126,7 @@ class TTSDataset(Dataset):
         self.verbose = verbose
         self.input_seq_computed = False
         self.rescue_item_idx = 1
+        self.pitch_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
            os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -240,10 +244,15 @@ class TTSDataset(Dataset):
             # TODO: find a better fix
             return self.load_data(self.rescue_item_idx)
 
+        pitch = None
+        if self.compute_f0:
+            pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+
         sample = {
             "raw_text": raw_text,
             "text": text,
             "wav": wav,
+            "pitch": pitch,
             "attn": attn,
             "item_idx": self.items[idx][1],
             "speaker_name": speaker_name,
@@ -260,8 +268,8 @@ class TTSDataset(Dataset):
         return phonemes
 
     def compute_input_seq(self, num_workers=0):
-        """compute input sequences separately. Call it before
-        passing dataset to data loader."""
+        """Compute the input sequences with multi-processing.
+        Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not self.use_phonemes:
             if self.verbose:
                 print(" | > Computing input sequences ...")
@@ -306,6 +314,64 @@ class TTSDataset(Dataset):
             for idx, p in enumerate(phonemes):
                 self.items[idx][0] = p
 
+    @staticmethod
+    def create_pitch_file_path(wav_file, cache_path):
+        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
+        return pitch_file
+
+    @staticmethod
+    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
+        wav = ap.load_wav(wav_file)
+        pitch = ap.compute_f0(wav)
+        if pitch_file:
+            np.save(pitch_file, pitch)
+        return pitch
+
+    @staticmethod
+    def _load_or_compute_pitch(ap, wav_file, cache_path):
+        """Load pitch from the cache file if it exists; otherwise compute it,
+        save it to the cache, and return it as a numpy array of pitch values.
+        """
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch
+
+    @staticmethod
+    def _pitch_worker(args):
+        item = args[0]
+        ap = args[1]
+        cache_path = args[2]
+        _, wav_file, *_ = item
+        pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
+        if not os.path.exists(pitch_file):
+            TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
+            return True
+        return False
+
+    def compute_pitch(self, cache_path, num_workers=0):
+        """Compute the pitch (f0) values with multi-processing.
+        Call it before passing the dataset to the data loader to cache the pitch values for faster data loading."""
+        if not os.path.exists(cache_path):
+            os.makedirs(cache_path, exist_ok=True)
+
+        if self.verbose:
+            print(" | > Computing pitch features ...")
+        if num_workers == 0:
+            for idx, item in enumerate(tqdm.tqdm(self.items)):
+                self._pitch_worker([item, self.ap, cache_path])
+        else:
+            with Pool(num_workers) as p:
+                _ = list(
+                    tqdm.tqdm(
+                        p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        total=len(self.items),
+                    )
+                )
+
     def sort_and_filter_items(self, by_audio_len=False):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the length
         range.
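The four new helpers are a standard load-or-compute cache keyed on the wav file's basename; note this assumes basenames are unique across the dataset (two speakers each shipping a `0001.wav` would collide in the cache). Collapsed into one function, the pattern is (a sketch using only numpy and the `ap` interface visible in the diff):

import os

import numpy as np

def load_or_compute_pitch(ap, wav_file, cache_path):
    # Cache key: "<wav basename>_pitch.npy" inside cache_path.
    name = os.path.splitext(os.path.basename(wav_file))[0]
    pitch_file = os.path.join(cache_path, name + "_pitch.npy")
    if os.path.exists(pitch_file):
        return np.load(pitch_file)  # cache hit: no audio decode, no f0 pass
    pitch = ap.compute_f0(ap.load_wav(wav_file))  # cache miss: compute once
    np.save(pitch_file, pitch)  # persist for later epochs and later runs
    return pitch

The boolean `_pitch_worker` returns (computed vs. already cached) is consumed only by the discarded `list(...)` around the tqdm progress loop and is otherwise unused.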
@@ -367,7 +433,7 @@ class TTSDataset(Dataset):
         r"""
         Perform preprocessing and create a final data batch:
         1. Sort batch instances by text-length
-        2. Convert Audio signal to Spectrograms.
+        2. Convert Audio signal to features.
         3. PAD sequences wrt r.
         4. Load to Torch.
         """
@@ -466,11 +532,12 @@ class TTSDataset(Dataset):
             # TODO: compare perf in collate_fn vs in load_data
-            pitch = None
             if self.compute_f0:
-                pitch = [self.ap.compute_f0(w).astype("float32") for w in wav]
-                pitch = prepare_tensor(pitch, self.outputs_per_step)
-                pitch = pitch.transpose(0, 2, 1)
-                assert mel.shape[1] == pitch.shape[1]
-                pitch = torch.FloatTensor(pitch).contiguous()
+                pitch = [b["pitch"] for b in batch]
+                pitch = prepare_data(pitch)
+                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
+                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
+            else:
+                pitch = None
 
             # collate attention alignments
             if batch[0]["attn"] is not None:
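With the cache in place, `collate_fn` no longer runs `compute_f0` per batch on every epoch; it only pads the per-item arrays that `load_data` fetched. `prepare_data`, per its use here, pads a list of 1-D arrays to the batch maximum; a self-contained equivalent of the new branch, assuming zero padding:

import numpy as np
import torch

def collate_pitch(batch):
    # batch: list of sample dicts, each with a 1-D "pitch" array of varying length
    pitch = [b["pitch"] for b in batch]
    max_len = max(len(p) for p in pitch)
    pitch = np.stack([np.pad(p, (0, max_len - len(p))) for p in pitch])
    # Same layout as the diff: B x 1 x T, frame-aligned with the mel batch.
    return torch.FloatTensor(pitch)[:, None, :].contiguous()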
@@ -478,6 +545,7 @@ class TTSDataset(Dataset):
                 for idx, attn in enumerate(attns):
                     pad2 = mel.shape[1] - attn.shape[1]
                     pad1 = text.shape[1] - attn.shape[0]
+                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                     attn = np.pad(attn, [[0, pad1], [0, pad2]])
                     attns[idx] = attn
                 attns = prepare_tensor(attns, self.outputs_per_step)
@@ -499,6 +567,7 @@ class TTSDataset(Dataset):
                 attns,
                 wav_padded,
                 raw_text,
+                pitch,
             )
 
         raise TypeError(
@@ -116,6 +116,7 @@ class BaseTTS(BaseModel):
         speaker_ids = batch[9]
         attn_mask = batch[10]
         waveform = batch[11]
+        pitch = batch[13]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
 
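`pitch = batch[13]` is consistent with the collate return tuple extended in the first file: `pitch` was appended right after `raw_text`, which sits at index 12, one past `waveform` at 11. For reference, the slots this diff documents:

# Positional slots visible in this diff (the rest of the tuple is unchanged):
speaker_ids = batch[9]
attn_mask = batch[10]
waveform = batch[11]
raw_text = batch[12]  # inferred: the element preceding pitch in the return tuple
pitch = batch[13]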
@@ -162,6 +163,7 @@ class BaseTTS(BaseModel):
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
             "waveform": waveform,
+            "pitch": pitch,
         }
 
     def get_data_loader(
@@ -200,6 +202,8 @@ class BaseTTS(BaseModel):
                 text_cleaner=config.text_cleaner,
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
+                compute_f0=config.get("compute_f0", False),
+                f0_cache_path=config.get("f0_cache_path", None),
                 meta_data=data_items,
                 ap=ap,
                 characters=config.characters,
@@ -246,6 +249,14 @@ class BaseTTS(BaseModel):
             # sort input sequences from short to long
             dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
 
+            # compute pitch frames and write to files.
+            if config.compute_f0 and not os.path.exists(config.f0_cache_path) and rank in [None, 0]:
+                dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
+
+            # halt DDP processes for the main process to finish computing the F0 cache
+            if num_gpus > 1:
+                dist.barrier()
+
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
 
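The guard above is the usual build-once-then-barrier idiom for DDP: only rank 0 (or a non-distributed run, rank None) writes the cache, while the other workers block on the barrier until the files exist. As a standalone sketch, assuming an initialized torch.distributed process group:

import os

import torch.distributed as dist

def build_cache_once(cache_path, compute_fn, rank=None, num_gpus=1):
    # Only the main process fills the cache; concurrent writers would race.
    if rank in [None, 0] and not os.path.exists(cache_path):
        compute_fn(cache_path)
    # Everyone else waits here until rank 0 finishes writing the cache.
    if num_gpus > 1:
        dist.barrier()

One consequence visible in the diff: because the call is skipped whenever `f0_cache_path` already exists, a partially written cache from an interrupted run is not completed by this parallel pass; missing files are only filled in lazily by `_load_or_compute_pitch` during training.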
@@ -9,8 +9,7 @@ import torch
 from torch import nn
 
 from TTS.tts.utils.data import StandardScaler
-
-# import pyworld as pw
+from TTS.utils.yin import compute_yin
 
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
@@ -648,15 +647,53 @@ class AudioProcessor(object):
         #     frame_period=1000 * self.hop_length / self.sample_rate,
         # )
         # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
         # f0 = compute_yin(, self.sample_rate, self.hop_length, self.fft_size)
-        f0, _, _ = librosa.pyin(
-            x.astype(np.double),
-            fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
-            fmax=self.mel_fmax,
-            frame_length=self.win_length,
-            sr=self.sample_rate,
-            fill_na=0.0,
+        f0, _, _, _ = compute_yin(
+            x,
+            self.sample_rate,
+            self.win_length,
+            self.hop_length,
+            65 if self.mel_fmin == 0 else self.mel_fmin,
+            self.mel_fmax,
         )
+        # import pyworld as pw
+        # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
+        #                frame_period=self.hop_length / self.sample_rate * 1000)
+        pad = int((self.win_length / self.hop_length) / 2)
+        f0 = [0.0] * pad + f0 + [0.0] * pad
+        f0 = np.array(f0, dtype=np.float32)
+
+        # f01, _, _ = librosa.pyin(
+        #     x,
+        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+        #     fmax=self.mel_fmax,
+        #     frame_length=self.win_length,
+        #     sr=self.sample_rate,
+        #     fill_na=0.0,
+        # )
+
+        # f02 = librosa.yin(
+        #     x,
+        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+        #     fmax=self.mel_fmax,
+        #     frame_length=self.win_length,
+        #     sr=self.sample_rate
+        # )
+
+        # spec = self.melspectrogram(x)
+
+        # from matplotlib import pyplot as plt
+        # plt.figure()
+        # plt.plot(f0, linewidth=2.5, color='red')
+        # plt.plot(f01, linewidth=2.5, linestyle='-.')
+        # plt.plot(f02, linewidth=2.5)
+        # plt.xlabel('time')
+        # plt.ylabel('F0')
+        # plt.savefig('save_img.png')
+
+        # # plt.figure()
+        # plt.imshow(spec, aspect="auto", origin="lower")
+        # plt.savefig('save_img2.png')
+        # breakpoint()
         return f0
 
     ### Audio Processing ###
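The padding arithmetic after `compute_yin` aligns the f0 track with centered STFT frames: yin emits one value per hop starting half a window into the signal, so `(win_length / hop_length) / 2` zero frames are added on each side, which is what lets `assert mel.shape[1] == pitch.shape[1]` in the collate hold. With common values (assumed here, not taken from this diff):

win_length, hop_length = 1024, 256        # assumed example values
pad = int((win_length / hop_length) / 2)  # = 2 frames on each side
f0 = [110.0, 112.0, 111.0]                # toy yin output, one value per hop
f0 = [0.0] * pad + f0 + [0.0] * pad       # 3 + 2*2 = 7 frames, matching the mel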