wavernn dataloader handling for short samples and mixed precision training

erogol 2020-10-28 12:31:01 +01:00
parent f4b8170bd1
commit 9d0ae2bfb4
3 changed files with 53 additions and 29 deletions

View File

@@ -94,6 +94,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
+    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -101,24 +102,43 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         loader_time = time.time() - end_time
         global_step += 1
-        y_hat = model(x_input, mels)
-        if isinstance(model.mode, int):
-            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-        else:
-            y_coarse = y_coarse.float()
-        y_coarse = y_coarse.unsqueeze(-1)
-        # compute losses
-        loss = criterion(y_hat, y_coarse)
-        if loss.item() is None:
-            raise RuntimeError(" [!] None loss. Exiting ...")
-        optimizer.zero_grad()
-        loss.backward()
-        if c.grad_clip > 0:
-            torch.nn.utils.clip_grad_norm_(
-                model.parameters(), c.grad_clip)
-        optimizer.step()
+        optimizer.zero_grad()
+        if c.mixed_precision:
+            # mixed precision training
+            with torch.cuda.amp.autocast():
+                y_hat = model(x_input, mels)
+                if isinstance(model.mode, int):
+                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+                else:
+                    y_coarse = y_coarse.float()
+                y_coarse = y_coarse.unsqueeze(-1)
+                # compute losses
+                loss = criterion(y_hat, y_coarse)
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), c.grad_clip)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            # full precision training
+            y_hat = model(x_input, mels)
+            if isinstance(model.mode, int):
+                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+            else:
+                y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
+            # compute losses
+            loss = criterion(y_hat, y_coarse)
+            if loss.item() is None:
+                raise RuntimeError(" [!] None loss. Exiting ...")
+            loss.backward()
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), c.grad_clip)
+            optimizer.step()
         if scheduler is not None:
             scheduler.step()
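The new branch follows the standard torch.cuda.amp recipe: a single GradScaler created before the loop, autocast around the forward pass and loss, and unscale_ before clipping so the threshold applies to true-scale gradients. A minimal, self-contained sketch of one such step, with model, optimizer, criterion, scaler, x, and y as placeholders rather than the trainer's actual objects:

import torch

def amp_train_step(model, optimizer, criterion, scaler, x, y, grad_clip=4.0):
    # one mixed-precision update, mirroring the hunk above
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward + loss in reduced precision
        y_hat = model(x)
        loss = criterion(y_hat, y)
    scaler.scale(loss).backward()    # scale the loss so fp16 grads don't underflow
    scaler.unscale_(optimizer)       # unscale first so clipping sees true grad norms
    if grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)           # skipped automatically if grads hit inf/nan
    scaler.update()                  # adapt the loss scale for the next step
    return loss.item()

As in the hunk, the scaler is constructed once (scaler = torch.cuda.amp.GradScaler()) and reused across iterations; creating it per step would discard the loss-scale history.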

View File

@@ -1,6 +1,6 @@
 {
-    "run_name": "wavernn_test",
-    "run_description": "wavernn_test training",
+    "run_name": "wavernn_librittts",
+    "run_description": "wavernn libritts training from LJSpeech model",
     // AUDIO PARAMETERS
     "audio": {
@@ -10,7 +10,7 @@
     "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
     "frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
     // Audio processing parameters
-    "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+    "sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
     "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "ref_level_db": 20, // reference level db, theoretically 20 db is the sound of air.
     // Silence trimming
@@ -58,14 +58,15 @@
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
-    "data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/", // path containing training wav files
+    "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files
     "feature_path": null, // path containing computed features from wav files; if null, compute them
     "seq_len": 1280, // has to be divisible by hop_length
     "padding": 2, // pad the input for resnet to see wider input length
     // TRAINING
-    "batch_size": 64, // Batch size for training.
+    "batch_size": 256, // Batch size for training.
     "epochs": 10000, // total number of epochs to train.
+    "mixed_precision": true, // enable / disable mixed precision training
     // VALIDATION
     "run_eval": true,

View File

@@ -26,12 +26,15 @@ class WaveRNNDataset(Dataset):
         self.item_list = items
         self.seq_len = seq_len
         self.hop_len = hop_len
+        self.mel_len = seq_len // hop_len
         self.pad = pad
         self.mode = mode
         self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
+
+        assert self.seq_len % self.hop_len == 0
 
     def __len__(self):
         return len(self.item_list)
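The two added lines pin down the frame/sample bookkeeping the loader relies on. A worked example with seq_len and padding from the config above, assuming hop_length = 256 (frame_shift_ms is null, so hop_length comes from the audio settings; 256 is an assumed value, not shown in this diff):

seq_len, hop_len, pad = 1280, 256, 2
assert seq_len % hop_len == 0                    # the new constructor assert
mel_len = seq_len // hop_len                     # 5 mel frames per training window
min_frames = mel_len + 2 * pad                   # 9 frames: the new load_item threshold
min_audio_len = 2 * seq_len + 2 * pad * hop_len  # 3584 samples, used below

If seq_len were not an exact multiple of hop_length, the sample and mel-frame crops computed in collate() could not be aligned, which is what the assert guards against.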
@@ -48,13 +51,12 @@ class WaveRNNDataset(Dataset):
             wavpath = self.item_list[index]
             audio = self.ap.load_wav(wavpath)
+            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+            if audio.shape[0] < min_audio_len:
+                print(" [!] Instance is too short! : {}".format(wavpath))
+                audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
             mel = self.ap.melspectrogram(audio)
-            if mel.shape[-1] < 5:
-                print(" [!] Instance is too short! : {}".format(wavpath))
-                self.item_list[index] = self.item_list[index + 1]
-                audio = self.ap.load_wav(wavpath)
-                mel = self.ap.melspectrogram(audio)
             if self.mode in ["gauss", "mold"]:
                 x_input = audio
             elif isinstance(self.mode, int):
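Rather than silently substituting the next item when a clip is too short (which skewed sampling toward neighbours and still required the replacement to be long enough), the loader now zero-pads short clips up to the minimum usable length plus one hop of headroom. A small demonstration with the numbers from the worked example above and a hypothetical 3000-sample clip:

import numpy as np

seq_len, hop_len, pad = 1280, 256, 2
min_audio_len = 2 * seq_len + (2 * pad * hop_len)  # 3584 samples
audio = np.zeros(3000, dtype=np.float32)           # a too-short clip
if audio.shape[0] < min_audio_len:
    audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + hop_len])
print(audio.shape[0])  # 3840 = 3584 target + 256 samples of headroom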
@@ -68,7 +70,7 @@ class WaveRNNDataset(Dataset):
             wavpath, feat_path = self.item_list[index]
             mel = np.load(feat_path.replace("/quant/", "/mel/"))
-            if mel.shape[-1] < 5:
+            if mel.shape[-1] < self.mel_len + 2 * self.pad:
                 print(" [!] Instance is too short! : {}".format(wavpath))
                 self.item_list[index] = self.item_list[index + 1]
                 feat_path = self.item_list[index]
@@ -80,12 +82,13 @@ class WaveRNNDataset(Dataset):
             else:
                 raise RuntimeError("Unknown dataset mode - ", self.mode)
-        return mel, x_input
+        return mel, x_input, wavpath
 
     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
         max_offsets = [x[0].shape[-1] -
                        (mel_win + 2 * self.pad) for x in batch]
         mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
         sig_offsets = [(offset + self.pad) *
                        self.hop_len for offset in mel_offsets]
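For reference, the offset arithmetic in collate() is what keeps the audio and mel crops aligned: a mel crop starting at frame mel_offset covers samples from mel_offset * hop_len onward, and shifting the signal crop in by pad frames leaves pad frames of mel context on each side of the samples being predicted. A sketch with one hypothetical 40-frame item (hop_length again assumed to be 256):

import numpy as np

seq_len, hop_len, pad = 1280, 256, 2
mel_win = seq_len // hop_len + 2 * pad         # 9-frame mel crop per item
n_frames = 40                                  # hypothetical item length in frames
max_offset = n_frames - (mel_win + 2 * pad)    # latest legal crop start
mel_offset = np.random.randint(0, max_offset)  # random crop, as in collate()
sig_offset = (mel_offset + pad) * hop_len      # aligned start in samples
# frames [mel_offset, mel_offset + mel_win) span samples
# [mel_offset * hop_len, (mel_offset + mel_win) * hop_len); starting the signal
# pad frames later centers the seq_len samples inside that mel context
print(mel_offset, sig_offset)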