mirror of https://github.com/coqui-ai/TTS.git
Config chnages for TWEB
This commit is contained in:
parent
18200a89b1
commit
5b0b290cee
11
config.json
11
config.json
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"num_mels": 80,
|
"num_mels": 80,
|
||||||
"num_freq": 1025,
|
"num_freq": 1025,
|
||||||
"sample_rate": 20000,
|
"sample_rate": 22050,
|
||||||
"frame_length_ms": 50,
|
"frame_length_ms": 50,
|
||||||
"frame_shift_ms": 12.5,
|
"frame_shift_ms": 12.5,
|
||||||
"preemphasis": 0.97,
|
"preemphasis": 0.97,
|
||||||
|
@ -15,16 +15,17 @@
|
||||||
"warmup_steps": 4000,
|
"warmup_steps": 4000,
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"eval_batch_size":32,
|
"eval_batch_size":32,
|
||||||
"r": 3,
|
"r": 5,
|
||||||
|
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
"power": 1.5,
|
"power": 1.2,
|
||||||
|
|
||||||
|
"dataset": "TWEB"
|
||||||
|
"data_path": "/run/shm/erogol/BibleSpeech/",
|
||||||
|
"min_seq_len": 0,
|
||||||
"num_loader_workers": 8,
|
"num_loader_workers": 8,
|
||||||
|
|
||||||
"checkpoint": true,
|
"checkpoint": true,
|
||||||
"save_step": 376,
|
"save_step": 376,
|
||||||
"data_path": "/run/shm/erogol/LJSpeech-1.0",
|
|
||||||
"min_seq_len": 0,
|
|
||||||
"output_path": "/data/shared/erogol_models/"
|
"output_path": "/data/shared/erogol_models/"
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
import librosa
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from TTS.utils.text import text_to_sequence
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.data import (prepare_data, pad_per_step,
|
||||||
|
prepare_tensor, prepare_stop_target)
|
||||||
|
|
||||||
|
|
||||||
|
class TWEBDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
||||||
|
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
||||||
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||||
|
min_seq_len=0):
|
||||||
|
|
||||||
|
with open(csv_file, "r") as f:
|
||||||
|
self.frames = [line.split('\t') for line in f]
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.outputs_per_step = outputs_per_step
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.cleaners = text_cleaner
|
||||||
|
self.min_seq_len = min_seq_len
|
||||||
|
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
|
||||||
|
print(" > Reading TWEB from - {}".format(root_dir))
|
||||||
|
print(" | > Number of instances : {}".format(len(self.frames)))
|
||||||
|
self._sort_frames()
|
||||||
|
|
||||||
|
def load_wav(self, filename):
|
||||||
|
try:
|
||||||
|
audio = librosa.core.load(filename, sr=self.sample_rate)
|
||||||
|
return audio
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(" !! Cannot read file : {}".format(filename))
|
||||||
|
|
||||||
|
def _sort_frames(self):
|
||||||
|
r"""Sort sequences in ascending order"""
|
||||||
|
lengths = np.array([len(ins[1]) for ins in self.frames])
|
||||||
|
|
||||||
|
print(" | > Max length sequence {}".format(np.max(lengths)))
|
||||||
|
print(" | > Min length sequence {}".format(np.min(lengths)))
|
||||||
|
print(" | > Avg length sequence {}".format(np.mean(lengths)))
|
||||||
|
|
||||||
|
idxs = np.argsort(lengths)
|
||||||
|
new_frames = []
|
||||||
|
ignored = []
|
||||||
|
for i, idx in enumerate(idxs):
|
||||||
|
length = lengths[idx]
|
||||||
|
if length < self.min_seq_len:
|
||||||
|
ignored.append(idx)
|
||||||
|
else:
|
||||||
|
new_frames.append(self.frames[idx])
|
||||||
|
print(" | > {} instances are ignored by min_seq_len ({})".format(
|
||||||
|
len(ignored), self.min_seq_len))
|
||||||
|
self.frames = new_frames
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.frames)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
wav_name = os.path.join(self.root_dir,
|
||||||
|
self.frames[idx][0]) + '.wav'
|
||||||
|
text = self.frames[idx][1]
|
||||||
|
text = np.asarray(text_to_sequence(
|
||||||
|
text, [self.cleaners]), dtype=np.int32)
|
||||||
|
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
|
||||||
|
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def get_dummy_data(self):
|
||||||
|
r"""Get a dummy input for testing"""
|
||||||
|
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
r"""
|
||||||
|
Perform preprocessing and create a final data batch:
|
||||||
|
1. PAD sequences with the longest sequence in the batch
|
||||||
|
2. Convert Audio signal to Spectrograms.
|
||||||
|
3. PAD sequences that can be divided by r.
|
||||||
|
4. Convert Numpy to Torch tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Puts each data field into a tensor with outer dimension batch size
|
||||||
|
if isinstance(batch[0], collections.Mapping):
|
||||||
|
keys = list()
|
||||||
|
|
||||||
|
wav = [d['wav'] for d in batch]
|
||||||
|
item_idxs = [d['item_idx'] for d in batch]
|
||||||
|
text = [d['text'] for d in batch]
|
||||||
|
|
||||||
|
text_lenghts = np.array([len(x) for x in text])
|
||||||
|
max_text_len = np.max(text_lenghts)
|
||||||
|
|
||||||
|
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||||
|
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||||
|
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
|
||||||
|
|
||||||
|
# compute 'stop token' targets
|
||||||
|
stop_targets = [np.array([0.]*(mel_len-1))
|
||||||
|
for mel_len in mel_lengths]
|
||||||
|
|
||||||
|
# PAD stop targets
|
||||||
|
stop_targets = prepare_stop_target(
|
||||||
|
stop_targets, self.outputs_per_step)
|
||||||
|
|
||||||
|
# PAD sequences with largest length of the batch
|
||||||
|
text = prepare_data(text).astype(np.int32)
|
||||||
|
wav = prepare_data(wav)
|
||||||
|
|
||||||
|
# PAD features with largest length + a zero frame
|
||||||
|
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||||
|
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||||
|
assert mel.shape[2] == linear.shape[2]
|
||||||
|
timesteps = mel.shape[2]
|
||||||
|
|
||||||
|
# B x T x D
|
||||||
|
linear = linear.transpose(0, 2, 1)
|
||||||
|
mel = mel.transpose(0, 2, 1)
|
||||||
|
|
||||||
|
# convert things to pytorch
|
||||||
|
text_lenghts = torch.LongTensor(text_lenghts)
|
||||||
|
text = torch.LongTensor(text)
|
||||||
|
linear = torch.FloatTensor(linear)
|
||||||
|
mel = torch.FloatTensor(mel)
|
||||||
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
|
stop_targets = torch.FloatTensor(stop_targets)
|
||||||
|
|
||||||
|
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
|
||||||
|
|
||||||
|
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||||
|
found {}"
|
||||||
|
.format(type(batch[0]))))
|
Loading…
Reference in New Issue