mirror of https://github.com/coqui-ai/TTS.git

Formatting changes and distributed training

This commit is contained in:
parent dce1715e0f
commit bf5f18d11e

config.json (12 changed lines)
@@ -3,7 +3,6 @@
     "model_description": "Queue memory and change lower r incrementatlly",
 
     "audio":{
-        "audio_processor": "audio", // to use dictate different audio processors, if available.
         // Audio processing parameters
         "num_mels": 80, // size of the mel spec frame.
         "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
@@ -25,6 +24,11 @@
         "do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
     },
 
+    "distributed":{
+        "backend": "nccl",
+        "url": "tcp:\/\/localhost:54321"
+    },
+
     "embedding_size": 256, // Character embedding vector length. You don't need to change it in general.
     "text_cleaner": "phoneme_cleaners",
     "epochs": 1000, // total number of epochs to train.
@@ -37,14 +41,16 @@
 
     "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":32,
-    "r": 2, // Number of frames to predict for step.
+    "r": 5, // Number of frames to predict for step.
     "wd": 0.00001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 5000, // Number of training steps expected to save traning stats and checkpoints.
     "print_step": 50, // Number of steps to log traning on console.
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+    "batch_group_size": 8, //Number of batches to shuffle after bucketing.
 
     "run_eval": true,
+    "test_delay_epochs": 100, //Until attention is aligned, testing only wastes computation time.
     "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument
     "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
     "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
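The new "distributed" block is consumed further down by train.py's main(), which passes the backend and URL to init_distributed. As a rough sketch of how such a pair typically reaches torch.distributed (the actual distribute.py helper is not part of this diff, so the names below are assumptions):

    # Hypothetical consumer of the {"backend": "nccl", "url": "tcp://localhost:54321"}
    # block; illustrative only, not the committed distribute.py code.
    import torch.distributed as dist

    def init_from_config(rank, num_gpus, dist_config):
        # NCCL expects one process per GPU; the URL is the rendezvous point.
        dist.init_process_group(
            backend=dist_config["backend"],
            init_method=dist_config["url"],
            world_size=num_gpus,
            rank=rank)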
datasets/TTSDataset.py

@@ -25,7 +25,8 @@ class MyDataset(Dataset):
                  cached=False,
                  use_phonemes=True,
                  phoneme_cache_path=None,
-                 phoneme_language="en-us"):
+                 phoneme_language="en-us",
+                 verbose=False):
         """
         Args:
             root_path (str): root path for the data folder.
@@ -47,6 +48,7 @@ class MyDataset(Dataset):
             phoneme_cache_path (str): path to cache phoneme features.
             phoneme_language (str): one the languages from
                 https://github.com/bootphon/phonemizer#languages
+            verbose (bool): print diagnostic information.
         """
         self.root_path = root_path
         self.batch_group_size = batch_group_size
@@ -61,16 +63,17 @@ class MyDataset(Dataset):
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
+        self.verbose = verbose
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
            os.makedirs(phoneme_cache_path)
-        print(" > DataLoader initialization")
-        print(" | > Data path: {}".format(root_path))
-        print(" | > Use phonemes: {}".format(self.use_phonemes))
-        if use_phonemes:
-            print(" | > phoneme language: {}".format(phoneme_language))
-        print(" | > Cached dataset: {}".format(self.cached))
-        print(" | > Number of instances : {}".format(len(self.items)))
+        if self.verbose:
+            print("\n > DataLoader initialization")
+            print(" | > Data path: {}".format(root_path))
+            print(" | > Use phonemes: {}".format(self.use_phonemes))
+            if use_phonemes:
+                print(" | > phoneme language: {}".format(phoneme_language))
+            print(" | > Cached dataset: {}".format(self.cached))
+            print(" | > Number of instances : {}".format(len(self.items)))
         self.sort_items()
 
     def load_wav(self, filename):
@@ -125,11 +128,7 @@ class MyDataset(Dataset):
     def sort_items(self):
         r"""Sort instances based on text length in ascending order"""
         lengths = np.array([len(ins[0]) for ins in self.items])
-
-        print(" | > Max length sequence: {}".format(np.max(lengths)))
-        print(" | > Min length sequence: {}".format(np.min(lengths)))
-        print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-
+
         idxs = np.argsort(lengths)
         new_items = []
         ignored = []
@@ -139,11 +138,8 @@ class MyDataset(Dataset):
                 ignored.append(idx)
             else:
                 new_items.append(self.items[idx])
-        print(" | > {} instances are ignored ({})".format(
-            len(ignored), self.min_seq_len))
         # shuffle batch groups
         if self.batch_group_size > 0:
-            print(" | > Batch group shuffling is active.")
             for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size
                 end_offset = offset + self.batch_group_size
@@ -152,6 +148,14 @@ class MyDataset(Dataset):
             new_items[offset : end_offset] = temp_items
         self.items = new_items
 
+        if self.verbose:
+            print(" | > Max length sequence: {}".format(np.max(lengths)))
+            print(" | > Min length sequence: {}".format(np.min(lengths)))
+            print(" | > Avg length sequence: {}".format(np.mean(lengths)))
+            print(" | > Num. instances discarded by max-min seq limits: {}".format(
+                len(ignored), self.min_seq_len))
+            print(" | > Batch group size: {}.".format(self.batch_group_size))
+
     def __len__(self):
         return len(self.items)
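One aside on the new verbose block in sort_items: str.format silently ignores surplus arguments, so the discarded-instances line prints only len(ignored) and never self.min_seq_len. If both values were meant to show, the format string needs two placeholders:

    # str.format drops extra positional args, so the committed line shows only
    # len(ignored). A two-slot variant would surface both values:
    print(" | > Num. instances discarded by max-min seq limits: {} (min len: {})".format(
        len(ignored), self.min_seq_len))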
(deleted file — old test sentences)

@@ -1,5 +0,0 @@
-Encouraged, he started with a minute a day.
-His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe.
-Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning.
-If he decided to watch TV he really watched it.
-Often we try to bring about change through sheer effort and we put all of our energy into a new initiative.
models/tacotron.py

@@ -21,7 +21,6 @@ class Tacotron(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(
             num_chars, embedding_dim, padding_idx=padding_idx)
-        print(" | > Number of characters : {}".format(num_chars))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)

train.py (408 changed lines)
@@ -1,38 +1,44 @@
-import os
-import sys
-import time
-import shutil
-import torch
 import argparse
 import importlib
+import os
+import shutil
+import sys
+import time
 import traceback
-import numpy as np
 
+import numpy as np
+import torch
 import torch.nn as nn
+from tensorboardX import SummaryWriter
 from torch import optim
 from torch.utils.data import DataLoader
-from tensorboardX import SummaryWriter
 
-from utils.generic_utils import (
-    remove_experiment_folder, create_experiment_folder, save_checkpoint,
-    save_best_model, load_config, lr_decay, count_parameters, check_update,
-    get_commit_hash, sequence_mask, NoamLR)
-from utils.text.symbols import symbols, phonemes
-from utils.visual import plot_alignment, plot_spectrogram
-from models.tacotron import Tacotron
-from layers.losses import L1LossMasked
 from datasets.TTSDataset import MyDataset
+from layers.losses import L1LossMasked
+from models.tacotron import Tacotron
 from utils.audio import AudioProcessor
-from utils.synthesis import synthesis
+from utils.generic_utils import (
+    NoamLR, check_update, count_parameters, create_experiment_folder,
+    get_commit_hash, load_config, lr_decay, remove_experiment_folder,
+    save_best_model, save_checkpoint, sequence_mask, weight_decay)
 from utils.logger import Logger
+from utils.synthesis import synthesis
+from utils.text.symbols import phonemes, symbols
+from utils.visual import plot_alignment, plot_spectrogram
+from distribute import init_distributed, apply_gradient_allreduce, reduce_tensor
+from distribute import DistributedSampler
 
-torch.manual_seed(1)
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = False
+torch.manual_seed(54321)
 use_cuda = torch.cuda.is_available()
+num_gpus = torch.cuda.device_count()
 print(" > Using CUDA: ", use_cuda)
-print(" > Number of GPUs: ", torch.cuda.device_count())
+print(" > Number of GPUs: ", num_gpus)
 
 
-def setup_loader(is_val=False):
+def setup_loader(is_val=False, verbose=False):
     global ap
     if is_val and not c.run_eval:
         loader = None
@@ -44,38 +50,44 @@ def setup_loader(is_val=False):
             c.text_cleaner,
             preprocessor=preprocessor,
             ap=ap,
-            batch_group_size=0 if is_val else 8 * c.batch_size,
+            batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
             min_seq_len=0 if is_val else c.min_seq_len,
             max_seq_len=float("inf") if is_val else c.max_seq_len,
             cached=False if c.dataset != "tts_cache" else True,
             phoneme_cache_path=c.phoneme_cache_path,
             use_phonemes=c.use_phonemes,
-            phoneme_language=c.phoneme_language
-        )
+            phoneme_language=c.phoneme_language,
+            verbose=verbose)
+        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
             batch_size=c.eval_batch_size if is_val else c.batch_size,
             shuffle=False,
             collate_fn=dataset.collate_fn,
             drop_last=False,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
+            sampler=sampler,
+            num_workers=c.num_val_loader_workers
+            if is_val else c.num_loader_workers,
             pin_memory=False)
     return loader
 
 
-def train(model, criterion, criterion_st, optimizer, optimizer_st,
-          scheduler, ap, epoch):
-    data_loader = setup_loader(is_val=False)
+def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
+          ap, epoch):
+    data_loader = setup_loader(is_val=False, verbose=(epoch==0))
     model.train()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
     avg_stop_loss = 0
     avg_step_time = 0
-    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
+    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     n_priority_freq = int(
         3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
-    batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
+    if num_gpus > 0:
+        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+    else:
+        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
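DistributedSampler is imported from distribute.py, which this commit references but does not show. Conceptually it hands each process a disjoint shard of the dataset, so shuffle=False plus a sampler still covers every item exactly once per epoch. A minimal sketch under that assumption (ShardSampler is a hypothetical name, not the committed class):

    import math
    from torch.utils.data.sampler import Sampler

    class ShardSampler(Sampler):
        """Hypothetical stand-in for distribute.DistributedSampler."""
        def __init__(self, dataset, num_replicas, rank):
            self.num_samples = int(math.ceil(len(dataset) / num_replicas))
            indices = list(range(len(dataset)))
            # pad so every rank iterates the same number of batches
            indices += indices[:self.num_samples * num_replicas - len(indices)]
            self.indices = indices[rank::num_replicas]

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return self.num_samples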
@@ -116,12 +128,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
         mask = sequence_mask(text_lengths)
 
         # forward pass
-        if use_cuda:
-            mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
-                model, (text_input, mel_input, mask))
-        else:
-            mel_output, linear_output, alignments, stop_tokens = model(
-                text_input, mel_input, mask)
+        mel_output, linear_output, alignments, stop_tokens = model(
+            text_input, mel_input, mask)
 
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
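The use_cuda branch with torch.nn.parallel.data_parallel disappears because multi-GPU work is now one process per GPU: each process runs a plain forward pass, and gradients are synchronized after backward via apply_gradient_allreduce. A hand-rolled sketch of that synchronization (the committed hook lives in distribute.py, not shown here):

    import torch.distributed as dist

    def average_gradients(model, world_size):
        # After loss.backward(), average each parameter's gradient across ranks.
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= world_size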
@@ -134,29 +142,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
 
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
-        for group in optimizer.param_groups:
-            for param in group['params']:
-                current_lr = group['lr']
-                param.data = param.data.add(-c.wd * group['lr'], param.data)
-        grad_norm, skip_flag = check_update(model, 1)
-        if skip_flag:
-            optimizer.zero_grad()
-            print(" | > Iteration skipped!!", flush=True)
-            continue
+        # custom weight decay
+        optimizer, current_lr = weight_decay(optimizer, c.wd)
+        grad_norm, _ = check_update(model, 1.0)
         optimizer.step()
 
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
-        for group in optimizer_st.param_groups:
-            for param in group['params']:
-                param.data = param.data.add(-c.wd * group['lr'], param.data)
-        grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
-        if skip_flag:
-            optimizer_st.zero_grad()
-            print(" | > Iteration skipped fro stopnet!!")
-            continue
+        # custom weight decay
+        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
         optimizer_st.step()
 
         step_time = time.time() - start_time
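Note a behavioral change hidden in this cleanup: the old loop consulted check_update's skip flag and dropped the whole iteration on an exploding gradient (and held the stopnet to 0.5 instead of 1.0); the new code ignores the flag and always steps. Assuming check_update clips via the usual clip_grad_norm_ and reports the pre-clip norm, the two policies compare like this:

    import torch

    grad = torch.tensor([100.0])      # pretend gradient
    max_norm = 1.0
    grad_norm = grad.norm()

    clipped = grad * min(1.0, (max_norm / grad_norm).item())
    # new policy: clip and always apply `clipped`
    # old policy: discard the update entirely when the norm explodes
    update = torch.zeros_like(grad) if grad_norm > max_norm else clipped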
@@ -164,49 +157,62 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
 
         if current_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                 "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
-                    num_iter, batch_n_iter, current_step, loss.item(),
-                    linear_loss.item(), mel_loss.item(), stop_loss.item(),
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
+                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}"
+                .format(num_iter, batch_n_iter, current_step, loss.item(),
+                        linear_loss.item(), mel_loss.item(), stop_loss.item(),
+                        grad_norm, grad_norm_st, avg_text_length,
+                        avg_spec_length, step_time, current_lr),
                 flush=True)
 
-        avg_linear_loss += float(linear_loss.item())
-        avg_mel_loss += float(mel_loss.item())
-        avg_stop_loss += stop_loss.item()
-        avg_step_time += step_time
+        # aggregate losses from processes
+        if num_gpus > 1:
+            linear_loss = reduce_tensor(linear_loss.data, num_gpus)
+            mel_loss = reduce_tensor(mel_loss.data, num_gpus)
+            loss = reduce_tensor(loss.data, num_gpus)
+            stop_loss = reduce_tensor(stop_loss.data, num_gpus)
 
-        # Plot Training Iter Stats
-        iter_stats = {"loss_posnet": linear_loss.item(),
-                      "loss_decoder": mel_loss.item(),
-                      "lr": current_lr,
-                      "grad_norm": grad_norm,
-                      "grad_norm_st": grad_norm_st,
-                      "step_time": step_time}
-        tb_logger.tb_train_iter_stats(current_step, iter_stats)
-
-        if current_step % c.save_step == 0:
-            if c.checkpoint:
-                # save model
-                save_checkpoint(model, optimizer, optimizer_st,
-                                linear_loss.item(), OUT_PATH, current_step,
-                                epoch)
-
-            # Diagnostic visualizations
-            const_spec = linear_output[0].data.cpu().numpy()
-            gt_spec = linear_input[0].data.cpu().numpy()
-            align_img = alignments[0].data.cpu().numpy()
-
-            figures = {"prediction": plot_spectrogram(const_spec, ap),
-                       "ground_truth": plot_spectrogram(gt_spec, ap),
-                       "alignment": plot_alignment(align_img)}
-            tb_logger.tb_train_figures(current_step, figures)
-
-            # Sample audio
-            tb_logger.tb_train_audios(current_step,
-                                      {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
-                                      c.audio["sample_rate"])
+        if args.rank == 0:
+            avg_linear_loss += float(linear_loss.item())
+            avg_mel_loss += float(mel_loss.item())
+            avg_stop_loss += stop_loss.item()
+            avg_step_time += step_time
+
+            # Plot Training Iter Stats
+            iter_stats = {
+                "loss_posnet": linear_loss.item(),
+                "loss_decoder": mel_loss.item(),
+                "lr": current_lr,
+                "grad_norm": grad_norm,
+                "grad_norm_st": grad_norm_st,
+                "step_time": step_time
+            }
+            tb_logger.tb_train_iter_stats(current_step, iter_stats)
+
+            if current_step % c.save_step == 0:
+                if c.checkpoint:
+                    # save model
+                    save_checkpoint(model, optimizer, optimizer_st,
+                                    linear_loss.item(), OUT_PATH, current_step,
+                                    epoch)
+
+                # Diagnostic visualizations
+                const_spec = linear_output[0].data.cpu().numpy()
+                gt_spec = linear_input[0].data.cpu().numpy()
+                align_img = alignments[0].data.cpu().numpy()
+
+                figures = {
+                    "prediction": plot_spectrogram(const_spec, ap),
+                    "ground_truth": plot_spectrogram(gt_spec, ap),
+                    "alignment": plot_alignment(align_img)
+                }
+                tb_logger.tb_train_figures(current_step, figures)
+
+                # Sample audio
+                tb_logger.tb_train_audios(
+                    current_step, {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
+                    c.audio["sample_rate"])
 
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
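reduce_tensor also comes from the unshown distribute.py. The standard implementation behind that name sums a tensor across all ranks and divides by the world size, so every process logs the same averaged loss:

    import torch.distributed as dist

    def reduce_tensor(tensor, num_gpus):
        # Presumed implementation, based on the conventional pattern;
        # not the committed distribute.py source.
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= num_gpus
        return rt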
@@ -216,7 +222,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
 
     # print epoch stats
     print(
         " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
         "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
         "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
         "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
@@ -224,25 +230,29 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
                                     avg_stop_loss, epoch_time, avg_step_time),
         flush=True)
 
-    # Plot Training Epoch Stats
-    epoch_stats = {"loss_postnet": avg_linear_loss,
-                   "loss_decoder": avg_mel_loss,
-                   "stop_loss": avg_stop_loss,
-                   "epoch_time": epoch_time}
-    tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
-    if c.tb_model_param_stats:
-        tb_logger.tb_model_weights(model, current_step)
+    # Plot Epoch Stats
+    if args.rank == 0:
+        # Plot Training Epoch Stats
+        epoch_stats = {
+            "loss_postnet": avg_linear_loss,
+            "loss_decoder": avg_mel_loss,
+            "stop_loss": avg_stop_loss,
+            "epoch_time": epoch_time
+        }
+        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
+        if c.tb_model_param_stats:
+            tb_logger.tb_model_weights(model, current_step)
     return avg_linear_loss, current_step
 
 
-def evaluate(model, criterion, criterion_st, ap, current_step):
+def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
     data_loader = setup_loader(is_val=True)
     model.eval()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
     avg_stop_loss = 0
-    print(" | > Validation")
+    print("\n > Validation")
     test_sentences = [
         "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
         "Be a voice, not an echo.",
@@ -296,74 +306,95 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
 
         if num_iter % c.print_step == 0:
             print(
                 " | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
                 "StopLoss: {:.5f} ".format(loss.item(),
                                            linear_loss.item(),
                                            mel_loss.item(),
                                            stop_loss.item()),
                 flush=True)
 
+        # aggregate losses from processes
+        if num_gpus > 1:
+            linear_loss = reduce_tensor(linear_loss.data, num_gpus)
+            mel_loss = reduce_tensor(mel_loss.data, num_gpus)
+            stop_loss = reduce_tensor(stop_loss.data, num_gpus)
+
         avg_linear_loss += float(linear_loss.item())
         avg_mel_loss += float(mel_loss.item())
         avg_stop_loss += stop_loss.item()
 
-    # Diagnostic visualizations
-    idx = np.random.randint(mel_input.shape[0])
-    const_spec = linear_output[idx].data.cpu().numpy()
-    gt_spec = linear_input[idx].data.cpu().numpy()
-    align_img = alignments[idx].data.cpu().numpy()
-
-    eval_figures = {"prediction": plot_spectrogram(const_spec, ap),
-                    "ground_truth": plot_spectrogram(gt_spec, ap),
-                    "alignment": plot_alignment(align_img)}
-    tb_logger.tb_eval_figures(current_step, eval_figures)
-
-    # Sample audio
-    tb_logger.tb_eval_audios(current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)}, c.audio["sample_rate"])
+    if args.rank == 0:
+        # Diagnostic visualizations
+        idx = np.random.randint(mel_input.shape[0])
+        const_spec = linear_output[idx].data.cpu().numpy()
+        gt_spec = linear_input[idx].data.cpu().numpy()
+        align_img = alignments[idx].data.cpu().numpy()
+
+        eval_figures = {
+            "prediction": plot_spectrogram(const_spec, ap),
+            "ground_truth": plot_spectrogram(gt_spec, ap),
+            "alignment": plot_alignment(align_img)
+        }
+        tb_logger.tb_eval_figures(current_step, eval_figures)
+
+        # Sample audio
+        tb_logger.tb_eval_audios(
+            current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)},
+            c.audio["sample_rate"])
 
     # compute average losses
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
     avg_stop_loss /= (num_iter + 1)
 
     # Plot Validation Stats
-    epoch_stats = {"loss_postnet": avg_linear_loss,
-                   "loss_decoder": avg_mel_loss,
-                   "stop_loss": avg_stop_loss}
-    tb_logger.tb_eval_stats(current_step, epoch_stats)
+    epoch_stats = {
+        "loss_postnet": avg_linear_loss,
+        "loss_decoder": avg_mel_loss,
+        "stop_loss": avg_stop_loss
+    }
+    tb_logger.tb_eval_stats(current_step, epoch_stats)
 
-    # test sentences
-    test_audios = {}
-    test_figures = {}
-    for idx, test_sentence in enumerate(test_sentences):
-        try:
-            wav, alignment, linear_spec, _, stop_tokens = synthesis(
-                model, test_sentence, c, use_cuda, ap)
-            file_path = os.path.join(AUDIO_PATH, str(current_step))
-            os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path,
-                                     "TestSentence_{}.wav".format(idx))
-            ap.save_wav(wav, file_path)
-            test_audios['{}-audio'.format(idx)] = wav
-            test_figures['{}-prediction'.format(idx)] = plot_spectrogram(linear_spec, ap)
-            test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
-        except:
-            print(" !! Error creating Test Sentence -", idx)
-            traceback.print_exc()
-    tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
-    tb_logger.tb_test_figures(current_step, test_figures)
+    if args.rank == 0 and epoch > c.test_delay_epochs:
+        # test sentences
+        test_audios = {}
+        test_figures = {}
+        print(" | > Synthesizing test sentences")
+        for idx, test_sentence in enumerate(test_sentences):
+            try:
+                wav, alignment, linear_spec, _, stop_tokens = synthesis(
+                    model, test_sentence, c, use_cuda, ap)
+                file_path = os.path.join(AUDIO_PATH, str(current_step))
+                os.makedirs(file_path, exist_ok=True)
+                file_path = os.path.join(file_path,
+                                         "TestSentence_{}.wav".format(idx))
+                ap.save_wav(wav, file_path)
+                test_audios['{}-audio'.format(idx)] = wav
+                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
+                    linear_spec, ap)
+                test_figures['{}-alignment'.format(idx)] = plot_alignment(
+                    alignment)
+            except:
+                print(" !! Error creating Test Sentence -", idx)
+                traceback.print_exc()
+        tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
+        tb_logger.tb_test_figures(current_step, test_figures)
     return avg_linear_loss
 
 
 def main(args):
+    # DISTRUBUTED
+    if num_gpus > 1:
+        init_distributed(args.rank, num_gpus, args.group_id,
+                         c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = Tacotron(num_chars=num_chars,
-                     embedding_dim=c.embedding_size,
-                     linear_dim=ap.num_freq,
-                     mel_dim=ap.num_mels,
-                     r=c.r,
-                     memory_size=c.memory_size)
-    print(" | > Num output units : {}".format(ap.num_freq), flush=True)
+    model = Tacotron(
+        num_chars=num_chars,
+        embedding_dim=c.embedding_size,
+        linear_dim=ap.num_freq,
+        mel_dim=ap.num_mels,
+        r=c.r,
+        memory_size=c.memory_size)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
     optimizer_st = optim.Adam(
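init_distributed is called with the rank, world size, group id, and the backend/url pair from config.json. Its body is not in this diff; a plausible shape, given that call site:

    import torch

    def init_distributed(rank, num_gpus, group_id, backend, url):
        # Assumed implementation sketch, matching the arguments used above;
        # the committed distribute.py may differ.
        assert torch.cuda.is_available(), "distributed mode requires CUDA"
        torch.cuda.set_device(rank % torch.cuda.device_count())
        torch.distributed.init_process_group(
            backend=backend, init_method=url,
            world_size=num_gpus, rank=rank, group_name=group_id)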
@@ -385,24 +416,26 @@ def main(args):
         # 1. filter out unnecessary keys
         pretrained_dict = {
             k: v
             for k, v in checkpoint['model'].items() if k in model_dict
         }
         # 2. filter out different size layers
         pretrained_dict = {
             k: v
-            for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()
+            for k, v in pretrained_dict.items()
+            if v.numel() == model_dict[k].numel()
         }
         # 3. overwrite entries in the existing state dict
         model_dict.update(pretrained_dict)
         # 4. load the new state dict
         model.load_state_dict(model_dict)
-        print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
+        print(" | > {} / {} layers are initialized".format(
+            len(pretrained_dict), len(model_dict)))
         if use_cuda:
             model = model.cuda()
             criterion.cuda()
             criterion_st.cuda()
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
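The numbered steps above are the stock recipe for restoring a checkpoint whose layer shapes have drifted (here, changing r resizes the decoder outputs). A self-contained toy run of the same four steps, with made-up tensors:

    import torch

    new = {"emb.weight": torch.zeros(10, 4), "fc.weight": torch.zeros(3, 4)}
    old = {"emb.weight": torch.ones(10, 4), "fc.weight": torch.ones(5, 4),
           "legacy.bias": torch.ones(7)}

    compat = {k: v for k, v in old.items() if k in new}          # 1. drop unknown keys
    compat = {k: v for k, v in compat.items()
              if v.numel() == new[k].numel()}                    # 2. drop resized layers
    new.update(compat)                                           # 3. overwrite matches
    # 4. would be model.load_state_dict(new)
    print("{} / {} layers restored".format(len(compat), len(new)))  # -> 1 / 2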
@@ -410,12 +443,15 @@ def main(args):
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
-        print("\n > Starting a new training", flush=True)
     if use_cuda:
         model = model.cuda()
         criterion.cuda()
         criterion_st.cuda()
 
+    # DISTRUBUTED
+    if num_gpus > 1:
+        model = apply_gradient_allreduce(model)
+
     if c.lr_decay:
         scheduler = NoamLR(
             optimizer,
@@ -425,22 +461,18 @@ def main(args):
         scheduler = None
 
     num_params = count_parameters(model)
-    print(" | > Model has {} parameters".format(num_params), flush=True)
-
-    if not os.path.exists(CHECKPOINT_PATH):
-        os.mkdir(CHECKPOINT_PATH)
+    print("\n > Model has {} parameters".format(num_params), flush=True)
 
     if 'best_loss' not in locals():
         best_loss = float('inf')
 
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(model, criterion, criterion_st,
-                                         optimizer, optimizer_st,
-                                         scheduler, ap, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, ap,
-                            current_step)
+                                         optimizer, optimizer_st, scheduler,
+                                         ap, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
         print(
-            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
+            " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
                 train_loss, val_loss),
             flush=True)
         target_loss = train_loss
@@ -468,31 +500,59 @@ if __name__ == '__main__':
         default=False,
         help='Do not verify commit integrity to run training.')
     parser.add_argument(
-        '--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.')
+        '--data_path',
+        type=str,
+        default='',
+        help='Defines the data path. It overwrites config.json.')
+    parser.add_argument(
+        '--output_path',
+        type=str,
+        help='path for training outputs.',
+        default='')
+
+    # DISTRUBUTED
+    parser.add_argument(
+        '--rank',
+        type=int,
+        default=0,
+        help='DISTRIBUTED: process rank for distributed training.')
+    parser.add_argument(
+        '--group_id',
+        type=str,
+        default="",
+        help='DISTRIBUTED: process group id.')
     args = parser.parse_args()
 
     # setup output paths and read configs
     c = load_config(args.config_path)
     _ = os.path.dirname(os.path.realpath(__file__))
-    OUT_PATH = os.path.join(_, c.output_path)
-    OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
-    CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-    os.makedirs(AUDIO_PATH, exist_ok=True)
-    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
 
     if args.data_path != '':
         c.data_path = args.data_path
 
-    # setup tensorboard
-    LOG_DIR = OUT_PATH
-    tb_logger = Logger(LOG_DIR)
+    if args.output_path == '':
+        OUT_PATH = os.path.join(_, c.output_path)
+    else:
+        OUT_PATH = args.output_path
+
+    if args.group_id == '':
+        OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
+
+    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
+
+    if args.rank == 0:
+        os.makedirs(AUDIO_PATH, exist_ok=True)
+        shutil.copyfile(args.config_path, os.path.join(OUT_PATH,
+                                                       'config.json'))
+        os.chmod(AUDIO_PATH, 0o775)
+        os.chmod(OUT_PATH, 0o775)
+
+    if args.rank==0:
+        LOG_DIR = OUT_PATH
+        tb_logger = Logger(LOG_DIR)
 
     # Conditional imports
     preprocessor = importlib.import_module('datasets.preprocess')
     preprocessor = getattr(preprocessor, c.dataset.lower())
-    audio = importlib.import_module('utils.' + c.audio['audio_processor'])
-    AudioProcessor = getattr(audio, 'AudioProcessor')
 
     # Audio processor
     ap = AudioProcessor(**c.audio)
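With --rank and --group_id available, each GPU runs its own train.py process. No launcher is included in this diff; a bare-bones one might look like the following (--config_path is assumed from the load_config(args.config_path) call above; everything else here is hypothetical):

    import subprocess
    import time
    import torch

    # Spawn one train.py per GPU, sharing a group id, each with its own rank.
    group_id = "group_{}".format(int(time.time()))
    procs = [
        subprocess.Popen([
            "python", "train.py",
            "--config_path", "config.json",
            "--rank", str(rank),
            "--group_id", group_id])
        for rank in range(torch.cuda.device_count())
    ]
    for p in procs:
        p.wait()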
utils/audio.py

@@ -50,10 +50,9 @@ class AudioProcessor(object):
         self.clip_norm = clip_norm
         self.do_trim_silence = do_trim_silence
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
-        print(" | > Audio Processor attributes.")
         members = vars(self)
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))
 
     def save_wav(self, wav, path):
         wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
@@ -118,8 +117,6 @@ class AudioProcessor(object):
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
-        print(" | > fft size: {}, hop length: {}, win length: {}".format(
-            n_fft, hop_length, win_length))
         return n_fft, hop_length, win_length
 
     def _amp_to_db(self, x):
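For reference, _stft_parameters is pure arithmetic on the config: num_freq 1025 (from the config.json hunk above) gives an FFT size of 2048, and hop/window lengths scale with the sample rate. A quick check, assuming 22050 Hz and 12.5/50 ms frames (those three values are not in this excerpt):

    # num_freq comes from config.json above; the rest are assumed for illustration.
    num_freq, sample_rate = 1025, 22050
    frame_shift_ms, frame_length_ms = 12.5, 50.0

    n_fft = (num_freq - 1) * 2                                # 2048
    hop_length = int(frame_shift_ms / 1000.0 * sample_rate)   # 275
    win_length = int(frame_length_ms / 1000.0 * sample_rate)  # 1102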
(deleted file — an older lws-based AudioProcessor module)

@@ -1,151 +0,0 @@
-import os
-import sys
-import librosa
-import pickle
-import copy
-import numpy as np
-from scipy import signal
-import lws
-
-_mel_basis = None
-
-
-class AudioProcessor(object):
-    def __init__(
-            self,
-            sample_rate,
-            num_mels,
-            min_level_db,
-            frame_shift_ms,
-            frame_length_ms,
-            ref_level_db,
-            num_freq,
-            power,
-            preemphasis,
-            min_mel_freq,
-            max_mel_freq,
-            griffin_lim_iters=None,
-    ):
-        print(" > Setting up Audio Processor...")
-        self.sample_rate = sample_rate
-        self.num_mels = num_mels
-        self.min_level_db = min_level_db
-        self.frame_shift_ms = frame_shift_ms
-        self.frame_length_ms = frame_length_ms
-        self.ref_level_db = ref_level_db
-        self.num_freq = num_freq
-        self.power = power
-        self.min_mel_freq = min_mel_freq
-        self.max_mel_freq = max_mel_freq
-        self.griffin_lim_iters = griffin_lim_iters
-        self.preemphasis = preemphasis
-        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
-        if preemphasis == 0:
-            print(" | > Preemphasis is deactive.")
-
-    def save_wav(self, wav, path):
-        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(
-            path, wav.astype(np.int16), self.sample_rate)
-
-    def _stft_parameters(self, ):
-        n_fft = int((self.num_freq - 1) * 2)
-        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
-        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
-        if n_fft % hop_length != 0:
-            hop_length = n_fft / 8
-            print(" | > hop_length is set to default ({}).".format(hop_length))
-        if n_fft % win_length != 0:
-            win_length = n_fft / 2
-            print(" | > win_length is set to default ({}).".format(win_length))
-        print(" | > fft size: {}, hop length: {}, win length: {}".format(
-            n_fft, hop_length, win_length))
-        return int(n_fft), int(hop_length), int(win_length)
-
-    def _lws_processor(self):
-        try:
-            return lws.lws(
-                self.win_length,
-                self.hop_length,
-                fftsize=self.n_fft,
-                mode="speech")
-        except:
-            raise RuntimeError(
-                " !! WindowLength({}) is not multiple of HopLength({}).".
-                format(self.win_length, self.hop_length))
-
-    def _amp_to_db(self, x):
-        min_level = np.exp(self.min_level_db / 20 * np.log(10))
-        return 20 * np.log10(np.maximum(min_level, x))
-
-    def _db_to_amp(self, x):
-        return np.power(10.0, x * 0.05)
-
-    def _normalize(self, S):
-        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
-
-    def _denormalize(self, S):
-        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
-
-    def apply_preemphasis(self, x):
-        if self.preemphasis == 0:
-            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
-        return signal.lfilter([1, -self.preemphasis], [1], x)
-
-    def apply_inv_preemphasis(self, x):
-        if self.preemphasis == 0:
-            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
-        return signal.lfilter([1], [1, -self.preemphasis], x)
-
-    def spectrogram(self, y):
-        f = open(os.devnull, 'w')
-        old_out = sys.stdout
-        sys.stdout = f
-        if self.preemphasis:
-            D = self._lws_processor().stft(self.apply_preemphasis(y)).T
-        else:
-            D = self._lws_processor().stft(y).T
-        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
-        sys.stdout = old_out
-        return self._normalize(S)
-
-    def inv_spectrogram(self, spectrogram):
-        '''Converts spectrogram to waveform using librosa'''
-        f = open(os.devnull, 'w')
-        old_out = sys.stdout
-        sys.stdout = f
-        S = self._denormalize(spectrogram)
-        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
-        processor = self._lws_processor()
-        D = processor.run_lws(S.astype(np.float64).T**self.power)
-        y = processor.istft(D).astype(np.float32)
-        # Reconstruct phase
-        sys.stdout = old_out
-        if self.preemphasis:
-            return self.apply_inv_preemphasis(y)
-        return y
-
-    def _linear_to_mel(self, spectrogram):
-        global _mel_basis
-        if _mel_basis is None:
-            _mel_basis = self._build_mel_basis()
-        return np.dot(_mel_basis, spectrogram)
-
-    def _build_mel_basis(self, ):
-        return librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.num_mels)
-        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
-
-    def melspectrogram(self, y):
-        f = open(os.devnull, 'w')
-        old_out = sys.stdout
-        sys.stdout = f
-        if self.preemphasis:
-            D = self._lws_processor().stft(self.apply_preemphasis(y)).T
-        else:
-            D = self._lws_processor().stft(y).T
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
-        sys.stdout = old_out
-        return self._normalize(S)
utils/generic_utils.py

@@ -123,7 +123,7 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
         best_loss = model_loss
         bestmodel_path = 'best_model.pth.tar'
         bestmodel_path = os.path.join(out_path, bestmodel_path)
-        print(" | > Best model saving with loss {0:.5f} : {1:}".format(
+        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
            model_loss, bestmodel_path))
         torch.save(state, bestmodel_path)
     return best_loss
@@ -148,6 +148,17 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
+def weight_decay(optimizer, wd):
+    """
+    Custom weight decay operation, not effecting grad values.
+    """
+    for group in optimizer.param_groups:
+        for param in group['params']:
+            current_lr = group['lr']
+            param.data = param.data.add(-wd * group['lr'], param.data)
+    return optimizer, current_lr
+
+
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
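The new weight_decay helper factors out the inline loops deleted from train.py: it shrinks every parameter by wd * lr of itself, outside the gradient path (the same decoupled-decay idea later popularized by AdamW), and returns the last-seen learning rate for logging. Numerically, per step:

    # param <- param - wd * lr * param  ==  param * (1 - wd * lr)
    wd, lr = 0.00001, 0.001   # wd from config.json; lr illustrative
    w = 2.0
    w = w - wd * lr * w
    assert abs(w - 2.0 * (1 - wd * lr)) < 1e-12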