mirror of https://github.com/coqui-ai/TTS.git

Refactor TTSDataset

Return a dict by `collate`.
Refactor batch handling in `collate`.
A couple of bug fixes.

This commit is contained in:
parent 6878ddfbea
commit 91a70e80b2
@@ -77,14 +77,14 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data[0]
-    text_lengths = data[1]
-    mel_input = data[4]
-    mel_lengths = data[5]
-    item_idx = data[7]
-    d_vectors = data[8]
-    speaker_ids = data[9]
-    attn_mask = data[10]
+    text_input = data['text']
+    text_lengths = data['text_lengths']
+    mel_input = data['mel']
+    mel_lengths = data['mel_lengths']
+    item_idx = data['item_idxs']
+    d_vectors = data['d_vectors']
+    speaker_ids = data['speaker_ids']
+    attn_mask = data['attns']
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
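The core of the refactor: `format_data` now reads batch fields by name instead of by tuple position. A minimal standalone sketch of the difference, with toy tensors rather than the Trainer's actual data:

import torch

# Old style: a sequence indexed by position; the meaning of each slot is implicit.
old_batch = [torch.zeros(2, 5), torch.tensor([5, 3])]  # slot 0: text, slot 1: text_lengths

# New style: a dict keyed by field name.
new_batch = {"text": torch.zeros(2, 5), "text_lengths": torch.tensor([5, 3])}

text_input = old_batch[0]        # breaks silently if the positional layout changes
text_input = new_batch["text"]   # raises KeyError instead of returning the wrong field
text_lengths = new_batch["text_lengths"]
print(text_input.shape, text_lengths)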
@@ -132,9 +132,8 @@ def inference(
         speaker_c = speaker_ids
     elif d_vectors is not None:
         speaker_c = d_vectors

     outputs = model.inference_with_MAS(
-        text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
+        text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
     )
     model_output = outputs["model_outputs"]
     model_output = model_output.transpose(1, 2).detach().cpu().numpy()
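The call site now forwards `speaker_ids` through `aux_input` as well, so speaker-embedding models are covered alongside d-vector models. A hypothetical stand-in sketching the convention; the real method is `GlowTTS.inference_with_MAS` in `TTS.tts.models.glow_tts` and does a full forward pass:

import torch

# Hypothetical stand-in, only to show the aux_input convention the call site relies on.
def inference_with_MAS(text, text_lengths, mel, mel_lengths, aux_input=None):
    aux_input = aux_input if aux_input is not None else {}
    d_vectors = aux_input.get("d_vectors")      # precomputed speaker embeddings, or None
    speaker_ids = aux_input.get("speaker_ids")  # ids for a learned speaker embedding, or None
    # a real model would condition its forward pass on whichever is not None
    return {"model_outputs": mel}

mel = torch.zeros(1, 80, 50)
out = inference_with_MAS(None, None, mel, None,
                         aux_input={"d_vectors": None, "speaker_ids": torch.tensor([0])})
print(out["model_outputs"].shape)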
@@ -271,7 +271,12 @@ class Trainer:
         # setup scheduler
         self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer)

-        if self.args.continue_path:
-            self.scheduler.last_epoch = self.restore_step
+        if self.scheduler is not None:
+            if self.args.continue_path:
+                if isinstance(self.scheduler, list):
+                    for scheduler in self.scheduler:
+                        scheduler.last_epoch = self.restore_step
+                else:
+                    self.scheduler.last_epoch = self.restore_step

         # DISTRUBUTED
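When continuing a run, `last_epoch` must be rewound to the restored step or the learning-rate schedule restarts from zero; the new guard also handles models that return a list of schedulers and the case where no scheduler is configured at all. The same pattern with a plain PyTorch scheduler, as a sketch whose names are illustrative rather than the Trainer's:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

restore_step = 5000  # step recovered from a checkpoint, illustrative
if scheduler is not None:
    schedulers = scheduler if isinstance(scheduler, list) else [scheduler]  # normalize
    for s in schedulers:
        s.last_epoch = restore_step  # resume the schedule where the checkpoint left off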
@@ -142,7 +142,7 @@ class BaseTTSConfig(BaseTrainingConfig):
             enable / disable masking loss values against padded segments of samples in a batch.

         sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `True`.
+            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.

         min_seq_len (int):
             Minimum sequence length to be used at training.
@@ -201,7 +201,7 @@ class BaseTTSConfig(BaseTrainingConfig):
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
-    sort_by_audio_len: bool = True
+    sort_by_audio_len: bool = False
     min_seq_len: int = 1
     max_seq_len: int = float("inf")
     compute_f0: bool = False
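With the default flipped, audio-length sorting now takes an explicit override. A sketch, assuming `BaseTTSConfig` is importable from `TTS.tts.configs.shared_configs` at this revision:

from TTS.tts.configs.shared_configs import BaseTTSConfig  # import path is an assumption

config = BaseTTSConfig()
assert config.sort_by_audio_len is False  # new default: sort by input text length
config.sort_by_audio_len = True           # opt back in to audio-length sorting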
@@ -9,7 +9,7 @@ import torch
 import tqdm
 from torch.utils.data import Dataset

-from TTS.tts.utils.data import _pad_data, prepare_data, prepare_stop_target, prepare_tensor
+from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
 from TTS.utils.audio import AudioProcessor
@@ -249,8 +249,8 @@ class TTSDataset(Dataset):

         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
-            pitch = self.pitch_extractor.normalize_pitch(pitch)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))

         sample = {
             "raw_text": raw_text,
@@ -356,6 +356,11 @@ class TTSDataset(Dataset):
                 temp_items = new_items[offset:end_offset]
                 random.shuffle(temp_items)
                 new_items[offset:end_offset] = temp_items

+        if len(new_items) == 0:
+            raise RuntimeError(" [!] No items left after filtering.")
+
+        # update items to the new sorted items
+
         self.items = new_items

         # logging
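The batch-group shuffle above keeps batches length-homogeneous while still varying their composition across epochs: items are sorted by length, then shuffled only within windows of `batch_group_size`; the new guard fails loudly if length filtering removed everything. A self-contained sketch of the same logic on integer stand-ins:

import random

items = sorted(range(20))  # stand-ins for samples sorted by sequence length
batch_group_size = 5
for offset in range(0, len(items), batch_group_size):
    end_offset = offset + batch_group_size
    temp_items = items[offset:end_offset]
    random.shuffle(temp_items)               # shuffle only inside the window
    items[offset:end_offset] = temp_items
if len(items) == 0:
    raise RuntimeError(" [!] No items left after filtering.")
print(items)  # locally shuffled, globally still roughly sorted by length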
@@ -376,6 +381,18 @@ class TTSDataset(Dataset):
     def __getitem__(self, idx):
         return self.load_data(idx)

+    @staticmethod
+    def _sort_batch(batch, text_lengths):
+        """Sort the batch by the input text length for RNN efficiency.
+
+        Args:
+            batch (Dict): Batch returned by `__getitem__`.
+            text_lengths (List[int]): Lengths of the input character sequences.
+        """
+        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
+        batch = [batch[idx] for idx in ids_sorted_decreasing]
+        return batch, text_lengths, ids_sorted_decreasing
+
     def collate_fn(self, batch):
         r"""
         Perform preprocessing and create a final data batch:
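The new helper's contract, sketched in isolation: `torch.sort(..., descending=True)` yields the sorted lengths and the permutation, and the list of sample dicts is reordered with that permutation (toy samples below):

import torch

batch = [{"text": "hello world"}, {"text": "hi"}, {"text": "a longer sentence"}]
text_lengths = [len(d["text"]) for d in batch]

# Same logic as TTSDataset._sort_batch: longest first, for RNN packing.
text_lengths, ids_sorted_decreasing = torch.sort(
    torch.LongTensor(text_lengths), dim=0, descending=True
)
batch = [batch[idx] for idx in ids_sorted_decreasing]
print([d["text"] for d in batch])  # ['a longer sentence', 'hello world', 'hi']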
@@ -388,30 +405,27 @@ class TTSDataset(Dataset):
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.abc.Mapping):

-            text_lenghts = np.array([len(d["text"]) for d in batch])
+            text_lengths = np.array([len(d["text"]) for d in batch])

             # sort items with text input length for RNN efficiency
-            text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
+            batch, text_lengths, ids_sorted_decreasing = self._sort_batch(batch, text_lengths)

-            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
-            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
-            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
-
-            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
+            # convert list of dicts to dict of lists
+            batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+
             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
-                wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
+                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
                 d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None
             # get numerical speaker ids from speaker names
             if self.speaker_id_mapping:
-                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
             else:
                 speaker_ids = None
             # compute features
-            mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
+            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

             mel_lengths = [m.shape[1] for m in mel]
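The one-liner `{k: [dic[k] for dic in batch] for k in batch[0]}` transposes a list of per-sample dicts into a single dict of per-field lists, which is what replaces the per-field comprehensions removed above. A standalone illustration; it assumes every sample carries the same keys:

# List of per-sample dicts, as produced by __getitem__ (toy values).
batch = [
    {"text": "ab", "wav": [0.1, 0.2], "speaker_name": "spk1"},
    {"text": "c", "wav": [0.3], "speaker_name": "spk2"},
]

# Transpose: one dict whose values are per-field lists, aligned by sample.
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
print(batch["text"])          # ['ab', 'c']
print(batch["speaker_name"])  # ['spk1', 'spk2']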
@@ -430,7 +444,7 @@ class TTSDataset(Dataset):
             stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

             # PAD sequences with longest instance in the batch
-            text = prepare_data(text).astype(np.int32)
+            text = prepare_data(batch["text"]).astype(np.int32)

             # PAD features with longest instance
             mel = prepare_tensor(mel, self.outputs_per_step)
@@ -439,7 +453,7 @@ class TTSDataset(Dataset):
             mel = mel.transpose(0, 2, 1)

             # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
+            text_lengths = torch.LongTensor(text_lengths)
             text = torch.LongTensor(text)
             mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
@@ -453,7 +467,7 @@ class TTSDataset(Dataset):

             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
+                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
|
||||||
# format waveforms
|
# format waveforms
|
||||||
wav_padded = None
|
wav_padded = None
|
||||||
if self.return_wav:
|
if self.return_wav:
|
||||||
wav_lengths = [w.shape[0] for w in wav]
|
wav_lengths = [w.shape[0] for w in batch["wav"]]
|
||||||
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
||||||
wav_lengths = torch.LongTensor(wav_lengths)
|
wav_lengths = torch.LongTensor(wav_lengths)
|
||||||
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
|
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
|
||||||
for i, w in enumerate(wav):
|
for i, w in enumerate(batch["wav"]):
|
||||||
mel_length = mel_lengths_adjusted[i]
|
mel_length = mel_lengths_adjusted[i]
|
||||||
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
||||||
w = w[: mel_length * self.ap.hop_length]
|
w = w[: mel_length * self.ap.hop_length]
|
||||||
|
@ -477,18 +491,16 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
# compute f0
|
# compute f0
|
||||||
# TODO: compare perf in collate_fn vs in load_data
|
# TODO: compare perf in collate_fn vs in load_data
|
||||||
pitch = None
|
|
||||||
if self.compute_f0:
|
if self.compute_f0:
|
||||||
pitch = [b["pitch"] for b in batch]
|
pitch = prepare_data(batch["pitch"])
|
||||||
pitch = prepare_data(pitch)
|
|
||||||
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
|
||||||
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
|
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
|
||||||
else:
|
else:
|
||||||
pitch = None
|
pitch = None
|
||||||
|
|
||||||
# collate attention alignments
|
# collate attention alignments
|
||||||
if batch[0]["attn"] is not None:
|
if batch["attn"][0] is not None:
|
||||||
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing]
|
||||||
for idx, attn in enumerate(attns):
|
for idx, attn in enumerate(attns):
|
||||||
pad2 = mel.shape[1] - attn.shape[1]
|
pad2 = mel.shape[1] - attn.shape[1]
|
||||||
pad1 = text.shape[1] - attn.shape[0]
|
pad1 = text.shape[1] - attn.shape[0]
|
||||||
|
@@ -502,18 +514,18 @@ class TTSDataset(Dataset):
             # TODO: return dictionary
             return {
                 "text": text,
-                "text_lengths": text_lenghts,
-                "speaker_names": speaker_names,
+                "text_lengths": text_lengths,
+                "speaker_names": batch["speaker_name"],
                 "linear": linear,
                 "mel": mel,
                 "mel_lengths": mel_lengths,
                 "stop_targets": stop_targets,
-                "item_idxs": item_idxs,
+                "item_idxs": batch["item_idx"],
                 "d_vectors": d_vectors,
                 "speaker_ids": speaker_ids,
                 "attns": attns,
                 "waveform": wav_padded,
-                "raw_text": raw_text,
+                "raw_text": batch["raw_text"],
                 "pitch": pitch,
             }
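Downstream code indexes the batch by key, matching the dict that `format_data` now expects. A runnable sketch with a toy stand-in for the dataset, not the real `TTSDataset`:

import torch
from torch.utils.data import DataLoader

# Toy stand-in: three samples already featurized.
samples = [{"text": torch.randint(0, 10, (n,)), "text_lengths": n} for n in (7, 5, 9)]

def collate(batch):
    # Dict-of-lists output, mirroring the refactored collate_fn's contract.
    return {k: [d[k] for d in batch] for k in batch[0]}

loader = DataLoader(samples, batch_size=3, collate_fn=collate)
for batch in loader:
    print(batch["text_lengths"])  # access by key, not by tuple position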
@@ -567,13 +579,20 @@ class PitchExtractor:

     def normalize_pitch(self, pitch):
         zero_idxs = np.where(pitch == 0.0)[0]
-        pitch -= self.mean
-        pitch /= self.std
+        pitch = pitch - self.mean
+        pitch = pitch / self.std
+        pitch[zero_idxs] = 0.0
+        return pitch
+
+    def denormalize_pitch(self, pitch):
+        zero_idxs = np.where(pitch == 0.0)[0]
+        pitch *= self.std
+        pitch += self.mean
         pitch[zero_idxs] = 0.0
         return pitch

     @staticmethod
-    def _load_or_compute_pitch(ap, wav_file, cache_path):
+    def load_or_compute_pitch(ap, wav_file, cache_path):
         """
         compute pitch and return a numpy array of pitch values
         """
|
||||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||||
else:
|
else:
|
||||||
pitch = np.load(pitch_file)
|
pitch = np.load(pitch_file)
|
||||||
return pitch
|
return pitch.astype(np.float32)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pitch_worker(args):
|
def _pitch_worker(args):
|
||||||
|
@@ -596,7 +615,7 @@ class PitchExtractor:
             return pitch
         return None

-    def compute_pitch(self, cache_path, num_workers=0):
+    def compute_pitch(self, ap, cache_path, num_workers=0):
         """Compute the input sequences with multi-processing.
         Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
         if not os.path.exists(cache_path):
@@ -607,12 +626,12 @@ class PitchExtractor:
         if num_workers == 0:
             pitch_vecs = []
             for _, item in enumerate(tqdm.tqdm(self.items)):
-                pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
+                pitch_vecs += [self._pitch_worker([item, ap, cache_path])]
         else:
             with Pool(num_workers) as p:
                 pitch_vecs = list(
                     tqdm.tqdm(
-                        p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
+                        p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]),
                         total=len(self.items),
                     )
                 )
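`compute_pitch` now receives the `AudioProcessor` as an argument rather than reading `self.ap`, matching the static `_pitch_worker` it delegates to. The surrounding Pool/imap/tqdm caching pattern, reduced to a runnable sketch; the worker body and path are illustrative:

import tqdm
from multiprocessing import Pool

def _worker(args):
    item, cache_path = args
    # ... compute and save one pitch contour here (illustrative) ...
    return item

def precompute(items, cache_path, num_workers=0):
    if num_workers == 0:
        return [_worker((item, cache_path)) for item in tqdm.tqdm(items)]
    with Pool(num_workers) as p:
        # imap yields results lazily, so tqdm can show live progress.
        return list(tqdm.tqdm(p.imap(_worker, [(i, cache_path) for i in items]),
                              total=len(items)))

if __name__ == "__main__":
    print(precompute(list(range(5)), "/tmp/f0_cache"))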
@@ -623,5 +642,5 @@ class PitchExtractor:
     def load_pitch_stats(self, cache_path):
         stats_path = os.path.join(cache_path, "pitch_stats.npy")
         stats = np.load(stats_path, allow_pickle=True).item()
-        self.mean = stats["mean"]
-        self.std = stats["std"]
+        self.mean = stats["mean"].astype(np.float32)
+        self.std = stats["std"].astype(np.float32)