mirror of https://github.com/coqui-ai/TTS.git
wavernn dataloader handling for short samples and mixed precision training
parent f4b8170bd1
commit 9d0ae2bfb4
@@ -94,6 +94,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
+    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -101,24 +102,43 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         loader_time = time.time() - end_time
         global_step += 1
 
-        y_hat = model(x_input, mels)
-
-        if isinstance(model.mode, int):
-            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-        else:
-            y_coarse = y_coarse.float()
-        y_coarse = y_coarse.unsqueeze(-1)
-
-        # compute losses
-        loss = criterion(y_hat, y_coarse)
-        if loss.item() is None:
-            raise RuntimeError(" [!] None loss. Exiting ...")
-        optimizer.zero_grad()
-        loss.backward()
-        if c.grad_clip > 0:
-            torch.nn.utils.clip_grad_norm_(
-                model.parameters(), c.grad_clip)
-        optimizer.step()
+        optimizer.zero_grad()
+
+        if c.mixed_precision:
+            # mixed precision training
+            with torch.cuda.amp.autocast():
+                y_hat = model(x_input, mels)
+                if isinstance(model.mode, int):
+                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+                else:
+                    y_coarse = y_coarse.float()
+                y_coarse = y_coarse.unsqueeze(-1)
+                # compute losses
+                loss = criterion(y_hat, y_coarse)
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), c.grad_clip)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            # full precision training
+            y_hat = model(x_input, mels)
+            if isinstance(model.mode, int):
+                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+            else:
+                y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
+            # compute losses
+            loss = criterion(y_hat, y_coarse)
+            if loss.item() is None:
+                raise RuntimeError(" [!] None loss. Exiting ...")
+            loss.backward()
+            if c.grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), c.grad_clip)
+            optimizer.step()
 
         if scheduler is not None:
             scheduler.step()
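The new branch above follows the standard PyTorch AMP recipe: the forward pass and loss run under autocast, the backward pass goes through a GradScaler, and gradients are unscaled before clipping so the threshold applies to true gradient magnitudes. Below is a minimal self-contained sketch of that recipe, with a toy model and data standing in for the WaveRNN pipeline (the grad_clip value is an assumption for illustration):

    import torch

    # Toy stand-ins; the real loop uses WaveRNN, its criterion, and c.grad_clip.
    model = torch.nn.Linear(80, 256).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()
    grad_clip = 4.0  # assumed value, read from the config in the real code

    for _ in range(10):
        x = torch.randn(64, 80, device="cuda")
        target = torch.randn(64, 256, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():    # forward + loss in fp16 where safe
            loss = criterion(model(x), target)
        scaler.scale(loss).backward()      # scale the loss to avoid fp16 underflow
        scaler.unscale_(optimizer)         # unscale grads so clipping sees real values
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)             # skips the step if grads overflowed
        scaler.update()                    # adapt the scale factor for the next iter

Unscaling before clip_grad_norm_ matters: without it, the clip threshold would be compared against scaled gradients and clipping would be effectively random.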
@@ -1,6 +1,6 @@
 {
-    "run_name": "wavernn_test",
-    "run_description": "wavernn_test training",
+    "run_name": "wavernn_librittts",
+    "run_description": "wavernn libritts training from LJSpeech model",
 
     // AUDIO PARAMETERS
     "audio": {
@@ -10,7 +10,7 @@
     "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
     "frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
     // Audio processing parameters
-    "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different from the original data, it is resampled.
+    "sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different from the original data, it is resampled.
     "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "ref_level_db": 20, // reference level db, theoretically 20 db is the sound of air.
     // Silence trimming
@@ -58,14 +58,15 @@
 
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
-    "data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/", // path containing training wav files
+    "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files
     "feature_path": null, // path containing computed features from wav files; if null, compute them
     "seq_len": 1280, // has to be divisible by hop_length
     "padding": 2, // pad the input for resnet to see wider input length
 
     // TRAINING
-    "batch_size": 64, // Batch size for training.
+    "batch_size": 256, // Batch size for training.
     "epochs": 10000, // total number of epochs to train.
+    "mixed_precision": true, // enable / disable mixed precision training
 
     // VALIDATION
     "run_eval": true,
@@ -26,12 +26,15 @@ class WaveRNNDataset(Dataset):
         self.item_list = items
         self.seq_len = seq_len
         self.hop_len = hop_len
+        self.mel_len = seq_len // hop_len
         self.pad = pad
         self.mode = mode
+        self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
 
+        assert self.seq_len % self.hop_len == 0
 
     def __len__(self):
         return len(self.item_list)
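mel_len is the number of spectrogram frames covered by one training window, which is only well-defined when seq_len is an exact multiple of hop_len; that is what the new assert enforces. A quick worked example (the hop_len value is an assumption for illustration; seq_len comes from this commit's config):

    seq_len, hop_len = 1280, 256   # seq_len from the config; hop_len assumed
    assert seq_len % hop_len == 0  # otherwise frames and samples drift out of sync
    mel_len = seq_len // hop_len   # 5 mel frames per training window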
@@ -48,13 +51,12 @@ class WaveRNNDataset(Dataset):
 
         wavpath = self.item_list[index]
         audio = self.ap.load_wav(wavpath)
+        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
+        if audio.shape[0] < min_audio_len:
+            print(" [!] Instance is too short! : {}".format(wavpath))
+            audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
         mel = self.ap.melspectrogram(audio)
-
-        if mel.shape[-1] < 5:
-            print(" [!] Instance is too short! : {}".format(wavpath))
-            self.item_list[index] = self.item_list[index + 1]
-            audio = self.ap.load_wav(wavpath)
-            mel = self.ap.melspectrogram(audio)
+
         if self.mode in ["gauss", "mold"]:
             x_input = audio
         elif isinstance(self.mode, int):
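The removed check used to swap in a neighboring item when a mel came out too short; the new code instead pads a too-short waveform up front, so the random crop in collate always has a valid offset range. A standalone sketch of the same arithmetic (hop_len is assumed for illustration; seq_len and padding match this commit's config):

    import numpy as np

    seq_len, hop_len, pad = 1280, 256, 2               # hop_len assumed

    min_audio_len = 2 * seq_len + (2 * pad * hop_len)  # 2560 + 1024 = 3584 samples
    audio = np.zeros(3000)                             # a too-short training instance

    if audio.shape[0] < min_audio_len:
        # pad at the end only; the extra hop_len adds roughly one mel frame of slack
        audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + hop_len])

    print(audio.shape[0])  # 3840 = 3584 + 256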
@@ -68,7 +70,7 @@ class WaveRNNDataset(Dataset):
         wavpath, feat_path = self.item_list[index]
         mel = np.load(feat_path.replace("/quant/", "/mel/"))
 
-        if mel.shape[-1] < 5:
+        if mel.shape[-1] < self.mel_len + 2 * self.pad:
             print(" [!] Instance is too short! : {}".format(wavpath))
             self.item_list[index] = self.item_list[index + 1]
             feat_path = self.item_list[index]
@@ -80,12 +82,13 @@ class WaveRNNDataset(Dataset):
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)
 
-        return mel, x_input
+        return mel, x_input, wavpath
 
     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
         max_offsets = [x[0].shape[-1] -
                        (mel_win + 2 * self.pad) for x in batch]
 
         mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
         sig_offsets = [(offset + self.pad) *
                        self.hop_len for offset in mel_offsets]
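collate picks each crop in mel frames and maps it back to audio samples: a window of mel_win frames is cut from the spectrogram, and the matching waveform slice starts pad frames further in, converted via hop_len. A standalone sketch of that offset arithmetic with toy shapes (hop_len is again an assumption):

    import numpy as np

    seq_len, hop_len, pad = 1280, 256, 2
    mel_win = seq_len // hop_len + 2 * pad            # 5 + 4 = 9 mel frames per window

    mel = np.random.randn(80, 40)                     # toy (n_mels, n_frames) spectrogram

    max_offset = mel.shape[-1] - (mel_win + 2 * pad)  # keep the whole window in range
    mel_offset = np.random.randint(0, max_offset)
    sig_offset = (mel_offset + pad) * hop_len         # mel frame index -> sample index

    mel_crop = mel[:, mel_offset:mel_offset + mel_win]
    assert mel_crop.shape[-1] == mel_win

This is also why the short-sample padding above matters: it keeps max_offset positive, so np.random.randint never sees an empty range.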