mirror of https://github.com/coqui-ai/TTS.git
Refactor TTSDataset
Return a dict by `collate` Refactor batch handling in `collate` A couple of bug fixes
This commit is contained in:
parent
6878ddfbea
commit
91a70e80b2
|
@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
|
|||
|
||||
def format_data(data):
|
||||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
item_idx = data[7]
|
||||
d_vectors = data[8]
|
||||
speaker_ids = data[9]
|
||||
attn_mask = data[10]
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
item_idx = data['item_idxs']
|
||||
d_vectors = data['d_vectors']
|
||||
speaker_ids = data['speaker_ids']
|
||||
attn_mask = data['attns']
|
||||
avg_text_length = torch.mean(text_lengths.float())
|
||||
avg_spec_length = torch.mean(mel_lengths.float())
|
||||
|
||||
|
@ -132,9 +132,8 @@ def inference(
|
|||
speaker_c = speaker_ids
|
||||
elif d_vectors is not None:
|
||||
speaker_c = d_vectors
|
||||
|
||||
outputs = model.inference_with_MAS(
|
||||
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
|
||||
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
|
||||
)
|
||||
model_output = outputs["model_outputs"]
|
||||
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
|
||||
|
|
|
@ -271,8 +271,13 @@ class Trainer:
|
|||
# setup scheduler
|
||||
self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)
|
||||
|
||||
if self.args.continue_path:
|
||||
self.scheduler.last_epoch = self.restore_step
|
||||
if self.scheduler is not None:
|
||||
if self.args.continue_path:
|
||||
if isinstance(self.scheduler, list):
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.last_epoch = self.restore_step
|
||||
else:
|
||||
self.scheduler.last_epoch = self.restore_step
|
||||
|
||||
# DISTRUBUTED
|
||||
if self.num_gpus > 1:
|
||||
|
|
|
@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
enable / disable masking loss values against padded segments of samples in a batch.
|
||||
|
||||
sort_by_audio_len (bool):
|
||||
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
|
||||
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
|
||||
|
||||
min_seq_len (int):
|
||||
Minimum sequence length to be used at training.
|
||||
|
@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
batch_group_size: int = 0
|
||||
loss_masking: bool = None
|
||||
# dataloading
|
||||
sort_by_audio_len: bool = True
|
||||
sort_by_audio_len: bool = False
|
||||
min_seq_len: int = 1
|
||||
max_seq_len: int = float("inf")
|
||||
compute_f0: bool = False
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
|
||||
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -249,8 +249,8 @@ class TTSDataset(Dataset):
|
|||
|
||||
pitch = None
|
||||
if self.compute_f0:
|
||||
pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||
pitch = self.pitch_extractor.normalize_pitch(pitch)
|
||||
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
|
||||
|
||||
sample = {
|
||||
"raw_text": raw_text,
|
||||
|
@ -356,6 +356,11 @@ class TTSDataset(Dataset):
|
|||
temp_items = new_items[offset:end_offset]
|
||||
random.shuffle(temp_items)
|
||||
new_items[offset:end_offset] = temp_items
|
||||
|
||||
if len(new_items) == 0:
|
||||
raise RuntimeError(" [!] No items left after filtering.")
|
||||
|
||||
# update items to the new sorted items
|
||||
self.items = new_items
|
||||
|
||||
# logging
|
||||
|
@ -376,6 +381,18 @@ class TTSDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
return self.load_data(idx)
|
||||
|
||||
@staticmethod
|
||||
def _sort_batch(batch, text_lengths):
|
||||
"""Sort the batch by the input text length for RNN efficiency.
|
||||
|
||||
Args:
|
||||
batch (Dict): Batch returned by `__getitem__`.
|
||||
text_lengths (List[int]): Lengths of the input character sequences.
|
||||
"""
|
||||
text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
|
||||
batch = [batch[idx] for idx in ids_sorted_decreasing]
|
||||
return batch, text_lengths, ids_sorted_decreasing
|
||||
|
||||
def collate_fn(self, batch):
|
||||
r"""
|
||||
Perform preprocessing and create a final data batch:
|
||||
|
@ -388,30 +405,27 @@ class TTSDataset(Dataset):
|
|||
# Puts each data field into a tensor with outer dimension batch size
|
||||
if isinstance(batch[0], collections.abc.Mapping):
|
||||
|
||||
text_lenghts = np.array([len(d["text"]) for d in batch])
|
||||
text_lengths = np.array([len(d["text"]) for d in batch])
|
||||
|
||||
# sort items with text input length for RNN efficiency
|
||||
text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
|
||||
batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)
|
||||
|
||||
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
|
||||
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
|
||||
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
|
||||
raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
|
||||
# convert list of dicts to dict of lists
|
||||
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||
|
||||
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
||||
# get pre-computed d-vectors
|
||||
if self.d_vector_mapping is not None:
|
||||
wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
|
||||
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
|
||||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||
else:
|
||||
d_vectors = None
|
||||
# get numerical speaker ids from speaker names
|
||||
if self.speaker_id_mapping:
|
||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
|
||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
||||
else:
|
||||
speaker_ids = None
|
||||
# compute features
|
||||
mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
|
||||
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
|
||||
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
|
||||
|
@ -430,7 +444,7 @@ class TTSDataset(Dataset):
|
|||
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
|
||||
|
||||
# PAD sequences with longest instance in the batch
|
||||
text = prepare_data(text).astype(np.int32)
|
||||
text = prepare_data(batch["text"]).astype(np.int32)
|
||||
|
||||
# PAD features with longest instance
|
||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||
|
@ -439,7 +453,7 @@ class TTSDataset(Dataset):
|
|||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
# convert things to pytorch
|
||||
text_lenghts = torch.LongTensor(text_lenghts)
|
||||
text_lengths = torch.LongTensor(text_lengths)
|
||||
text = torch.LongTensor(text)
|
||||
mel = torch.FloatTensor(mel).contiguous()
|
||||
mel_lengths = torch.LongTensor(mel_lengths)
|
||||
|
@ -453,7 +467,7 @@ class TTSDataset(Dataset):
|
|||
|
||||
# compute linear spectrogram
|
||||
if self.compute_linear_spec:
|
||||
linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
|
||||
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
assert mel.shape[1] == linear.shape[1]
|
||||
|
@ -464,11 +478,11 @@ class TTSDataset(Dataset):
|
|||
# format waveforms
|
||||
wav_padded = None
|
||||
if self.return_wav:
|
||||
wav_lengths = [w.shape[0] for w in wav]
|
||||
wav_lengths = [w.shape[0] for w in batch["wav"]]
|
||||
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
||||
wav_lengths = torch.LongTensor(wav_lengths)
|
||||
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
|
||||
for i, w in enumerate(wav):
|
||||
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
|
||||
for i, w in enumerate(batch["wav"]):
|
||||
mel_length = mel_lengths_adjusted[i]
|
||||
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
||||
w = w[: mel_length * self.ap.hop_length]
|
||||
|
@ -477,18 +491,16 @@ class TTSDataset(Dataset):
|
|||
|
||||
# compute f0
|
||||
# TODO: compare perf in collate_fn vs in load_data
|
||||
pitch = None
|
||||
if self.compute_f0:
|
||||
pitch = [b["pitch"] for b in batch]
|
||||
pitch = prepare_data(pitch)
|
||||
pitch = prepare_data(batch["pitch"])
|
||||
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
||||
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
|
||||
else:
|
||||
pitch = None
|
||||
|
||||
# collate attention alignments
|
||||
if batch[0]["attn"] is not None:
|
||||
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
||||
if batch["attn"][0] is not None:
|
||||
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
||||
for idx, attn in enumerate(attns):
|
||||
pad2 = mel.shape[1] - attn.shape[1]
|
||||
pad1 = text.shape[1] - attn.shape[0]
|
||||
|
@ -502,18 +514,18 @@ class TTSDataset(Dataset):
|
|||
# TODO: return dictionary
|
||||
return {
|
||||
"text": text,
|
||||
"text_lengths": text_lenghts,
|
||||
"speaker_names": speaker_names,
|
||||
"text_lengths": text_lengths,
|
||||
"speaker_names": batch["speaker_name"],
|
||||
"linear": linear,
|
||||
"mel": mel,
|
||||
"mel_lengths": mel_lengths,
|
||||
"stop_targets": stop_targets,
|
||||
"item_idxs": item_idxs,
|
||||
"item_idxs": batch["item_idx"],
|
||||
"d_vectors": d_vectors,
|
||||
"speaker_ids": speaker_ids,
|
||||
"attns": attns,
|
||||
"waveform": wav_padded,
|
||||
"raw_text": raw_text,
|
||||
"raw_text": batch["raw_text"],
|
||||
"pitch": pitch,
|
||||
}
|
||||
|
||||
|
@ -567,13 +579,20 @@ class PitchExtractor:
|
|||
|
||||
def normalize_pitch(self, pitch):
|
||||
zero_idxs = np.where(pitch == 0.0)[0]
|
||||
pitch -= self.mean
|
||||
pitch /= self.std
|
||||
pitch = pitch - self.mean
|
||||
pitch = pitch / self.std
|
||||
pitch[zero_idxs] = 0.0
|
||||
return pitch
|
||||
|
||||
def denormalize_pitch(self, pitch):
|
||||
zero_idxs = np.where(pitch == 0.0)[0]
|
||||
pitch *= self.std
|
||||
pitch += self.mean
|
||||
pitch[zero_idxs] = 0.0
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def _load_or_compute_pitch(ap, wav_file, cache_path):
|
||||
def load_or_compute_pitch(ap, wav_file, cache_path):
|
||||
"""
|
||||
compute pitch and return a numpy array of pitch values
|
||||
"""
|
||||
|
@ -582,7 +601,7 @@ class PitchExtractor:
|
|||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||
else:
|
||||
pitch = np.load(pitch_file)
|
||||
return pitch
|
||||
return pitch.astype(np.float32)
|
||||
|
||||
@staticmethod
|
||||
def _pitch_worker(args):
|
||||
|
@ -596,7 +615,7 @@ class PitchExtractor:
|
|||
return pitch
|
||||
return None
|
||||
|
||||
def compute_pitch(self, cache_path, num_workers=0):
|
||||
def compute_pitch(self, ap, cache_path, num_workers=0):
|
||||
"""Compute the input sequences with multi-processing.
|
||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
||||
if not os.path.exists(cache_path):
|
||||
|
@ -607,12 +626,12 @@ class PitchExtractor:
|
|||
if num_workers == 0:
|
||||
pitch_vecs = []
|
||||
for _, item in enumerate(tqdm.tqdm(self.items)):
|
||||
pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
|
||||
pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
|
||||
else:
|
||||
with Pool(num_workers) as p:
|
||||
pitch_vecs = list(
|
||||
tqdm.tqdm(
|
||||
p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
|
||||
p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
|
||||
total=len(self.items),
|
||||
)
|
||||
)
|
||||
|
@ -623,5 +642,5 @@ class PitchExtractor:
|
|||
def load_pitch_stats(self, cache_path):
|
||||
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||
stats = np.load(stats_path, allow_pickle=True).item()
|
||||
self.mean = stats["mean"]
|
||||
self.std = stats["std"]
|
||||
self.mean = stats["mean"].astype(np.float32)
|
||||
self.std = stats["std"].astype(np.float32)
|
||||
|
|
Loading…
Reference in New Issue