mirror of https://github.com/coqui-ai/TTS.git
wavernn dataloader handling for short samples and mixed precision training
parent f4b8170bd1
commit 9d0ae2bfb4
@@ -94,6 +94,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
+    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -101,19 +102,38 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         loader_time = time.time() - end_time
         global_step += 1
 
-        y_hat = model(x_input, mels)
-
-        if isinstance(model.mode, int):
-            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-        else:
-            y_coarse = y_coarse.float()
-        y_coarse = y_coarse.unsqueeze(-1)
-
-        # compute losses
-        loss = criterion(y_hat, y_coarse)
-        if loss.item() is None:
-            raise RuntimeError(" [!] None loss. Exiting ...")
-        optimizer.zero_grad()
-        loss.backward()
-        if c.grad_clip > 0:
-            torch.nn.utils.clip_grad_norm_(
+        optimizer.zero_grad()
+
+        if c.mixed_precision:
+            # mixed precision training
+            with torch.cuda.amp.autocast():
+                y_hat = model(x_input, mels)
+                if isinstance(model.mode, int):
+                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+                else:
+                    y_coarse = y_coarse.float()
+                y_coarse = y_coarse.unsqueeze(-1)
+                # compute losses
+                loss = criterion(y_hat, y_coarse)
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), c.grad_clip)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            # full precision training
+            y_hat = model(x_input, mels)
+            if isinstance(model.mode, int):
+                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+            else:
+                y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
+            # compute losses
+            loss = criterion(y_hat, y_coarse)
+            if loss.item() is None:
+                raise RuntimeError(" [!] None loss. Exiting ...")
+            loss.backward()
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
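For readers unfamiliar with the AMP API used above: the pattern is a forward pass and loss under autocast, backward on the scaled loss, unscale before gradient clipping, then step and update. A minimal self-contained sketch of one such step follows; the linear model, data, and clip value are illustrative stand-ins, not the WaveRNN objects from this diff.

import torch

# Minimal mixed-precision training step (stand-in model/data, assumed clip value).
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()        # scales the loss so fp16 grads do not underflow
grad_clip = 4.0                             # stand-in for c.grad_clip

x = torch.randn(8, 10, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():             # run forward + loss in mixed precision
    loss = criterion(model(x), y)
scaler.scale(loss).backward()               # backward on the scaled loss
scaler.unscale_(optimizer)                  # unscale so clipping sees true gradient norms
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)                      # skips the step if gradients overflowed
scaler.update()                             # adapts the loss scale for the next iteration

Calling unscale_ before clip_grad_norm_ matters: clipping the still-scaled gradients would apply the threshold to values inflated by the loss scale.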
@@ -1,6 +1,6 @@
 {
-    "run_name": "wavernn_test",
-    "run_description": "wavernn_test training",
+    "run_name": "wavernn_librittts",
+    "run_description": "wavernn libritts training from LJSpeech model",
 
     // AUDIO PARAMETERS
     "audio": {
@@ -10,7 +10,7 @@
     "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
     "frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
     // Audio processing parameters
-    "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+    "sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
     "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
     // Silence trimming
@@ -58,14 +58,15 @@
 
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
-    "data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/", // path containing training wav files
+    "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files
     "feature_path": null, // path containing computed features from wav files; if null, compute them
     "seq_len": 1280, // has to be divisible by hop_length
     "padding": 2, // pad the input for resnet to see wider input length
 
     // TRAINING
-    "batch_size": 64, // Batch size for training.
+    "batch_size": 256, // Batch size for training.
     "epochs": 10000, // total number of epochs to train.
+    "mixed_precision": true, // enable / disable mixed precision training
 
     // VALIDATION
     "run_eval": true,
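The new mixed_precision key could also be wired straight into the AMP objects through their enabled arguments, keeping a single code path for both modes; the training script above branches explicitly instead. A hedged sketch of the flag-driven variant, with a toy model and a local flag standing in for the config value:

import torch

mixed_precision = True                       # stands in for the new config key
model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=mixed_precision):
    loss = model(torch.randn(2, 4, device="cuda")).pow(2).mean()
scaler.scale(loss).backward()                # behaves like a plain backward() when disabled
scaler.step(optimizer)
scaler.update()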
@@ -26,12 +26,15 @@ class WaveRNNDataset(Dataset):
         self.item_list = items
         self.seq_len = seq_len
         self.hop_len = hop_len
+        self.mel_len = seq_len // hop_len
         self.pad = pad
         self.mode = mode
         self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
 
+        assert self.seq_len % self.hop_len == 0
+
     def __len__(self):
         return len(self.item_list)
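To make the new length bookkeeping concrete: one mel frame covers hop_len audio samples, so a training window of seq_len samples spans seq_len // hop_len frames, and the resnet padding adds pad frames of context on each side. A quick check with the config's seq_len of 1280 and an assumed hop_length of 256 (the hop value is not shown in this diff):

seq_len, hop_len, pad = 1280, 256, 2      # hop_len of 256 is an assumption

assert seq_len % hop_len == 0             # the constraint asserted in __init__
mel_len = seq_len // hop_len              # 5 mel frames per training window
min_frames = mel_len + 2 * pad            # 9 frames: the short-item threshold used in __getitem__
print(mel_len, min_frames)                # 5 9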
@@ -48,13 +51,12 @@ class WaveRNNDataset(Dataset):
 
             wavpath = self.item_list[index]
             audio = self.ap.load_wav(wavpath)
+            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+            if audio.shape[0] < min_audio_len:
+                print(" [!] Instance is too short! : {}".format(wavpath))
+                audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
             mel = self.ap.melspectrogram(audio)
 
-            if mel.shape[-1] < 5:
-                print(" [!] Instance is too short! : {}".format(wavpath))
-                self.item_list[index] = self.item_list[index + 1]
-                audio = self.ap.load_wav(wavpath)
-                mel = self.ap.melspectrogram(audio)
             if self.mode in ["gauss", "mold"]:
                 x_input = audio
             elif isinstance(self.mode, int):
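The hunk above replaces the old behaviour of swapping a too-short item for its neighbour with zero-padding the waveform, so every item can still yield a full training window plus the resnet context. A small sketch of that padding logic in isolation, with the same assumed sizes as above:

import numpy as np

seq_len, hop_len, pad = 1280, 256, 2                     # assumed values
min_audio_len = 2 * seq_len + (2 * pad * hop_len)        # 3584 samples here

audio = np.random.uniform(-1, 1, size=2000).astype(np.float32)  # a clip that is too short
if audio.shape[0] < min_audio_len:
    # zero-pad the tail so a full window (plus one spare hop) can always be cut
    audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + hop_len])

print(audio.shape[0])                                    # 3840 >= min_audio_len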
@@ -68,7 +70,7 @@ class WaveRNNDataset(Dataset):
             wavpath, feat_path = self.item_list[index]
             mel = np.load(feat_path.replace("/quant/", "/mel/"))
 
-            if mel.shape[-1] < 5:
+            if mel.shape[-1] < self.mel_len + 2 * self.pad:
                 print(" [!] Instance is too short! : {}".format(wavpath))
                 self.item_list[index] = self.item_list[index + 1]
                 feat_path = self.item_list[index]
@@ -80,12 +82,13 @@ class WaveRNNDataset(Dataset):
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)
 
-        return mel, x_input
+        return mel, x_input, wavpath
+
     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
         max_offsets = [x[0].shape[-1] -
                        (mel_win + 2 * self.pad) for x in batch]
 
         mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
         sig_offsets = [(offset + self.pad) *
                        self.hop_len for offset in mel_offsets]
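The collate code above crops a random window of mel_win frames from each item's spectrogram and the audio slice aligned with its interior; the (offset + self.pad) * self.hop_len conversion skips the pad context frames and lands on the matching sample index. A self-contained sketch of that alignment for one item, again with assumed sizes:

import numpy as np

seq_len, hop_len, pad = 1280, 256, 2             # assumed values, as above
mel_win = seq_len // hop_len + 2 * pad           # 9 conditioning frames per crop

# one fake item: 40 mel frames (80 bands) and the matching 40 * hop_len samples
mel = np.random.rand(80, 40).astype(np.float32)
audio = np.random.uniform(-1, 1, 40 * hop_len).astype(np.float32)

max_offset = mel.shape[-1] - (mel_win + 2 * pad)         # keep the crop inside the item
mel_offset = np.random.randint(0, max_offset)
sig_offset = (mel_offset + pad) * hop_len                # audio index of the window interior

mel_crop = mel[:, mel_offset:mel_offset + mel_win]       # (80, 9)
audio_crop = audio[sig_offset:sig_offset + seq_len]      # (1280,)
print(mel_crop.shape, audio_crop.shape)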