initial wavegrad layers, model and training script

This commit is contained in:
erogol 2020-10-16 16:35:25 +02:00
parent ac57eea928
commit e02cd6a220
5 changed files with 987 additions and 0 deletions

490
TTS/bin/train_wavegrad.py Normal file

@@ -0,0 +1,490 @@
import argparse
import glob
import os
import sys
import time
import traceback
from inspect import signature
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
setup_generator)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
if is_val and not c.run_eval:
loader = None
else:
dataset = WaveGradDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=True,
use_noise_augment=False,
use_cache=c.use_cache,
verbose=verbose)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.batch_size,
shuffle=False if num_gpus > 1 else True,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
return loader
def format_data(data):
# return a whole audio segment
m, y = data
if use_cuda:
m = m.cuda(non_blocking=True)
y = y.cuda(non_blocking=True)
return m, y
def train(model, criterion, optimizer,
scheduler, ap, global_step, epoch, amp):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
m, y = format_data(data)
loader_time = time.time() - end_time
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
y_noisy, noise_scale = model.module.compute_noisy_x(y)
else:
y_noisy, noise_scale = model.compute_noisy_x(y)
# forward pass
y_hat = model(y_noisy, m, noise_scale)
# compute losses
loss = criterion(y_noisy, y_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss}
# backward pass with loss scaling
optimizer.zero_grad()
if amp is not None:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
        if c.clip_grad > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
        else:
            grad_norm = 0
optimizer.step()
# schedule update
if scheduler is not None:
scheduler.step()
# disconnect loss values
loss_dict = dict()
for key, value in loss_wavegrad_dict.items():
if isinstance(value, int):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
# epoch/step timing
step_time = time.time() - start_time
epoch_time += step_time
# get current learning rates
current_lr = list(optimizer.param_groups)[0]['lr']
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
keep_avg.update_values(update_train_values)
# print training stats
if global_step % c.print_step == 0:
log_dict = {
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"current_lr": current_lr,
"grad_norm": grad_norm
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# plot step stats
if global_step % 10 == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
# save checkpoint
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model,
optimizer,
scheduler,
None,
None,
None,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict)
# compute spectrograms
figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
# Plot Training Epoch Stats
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(keep_avg.avg_values)
if args.rank == 0:
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
# TODO: plot model stats
# if c.tb_model_param_stats:
# tb_logger.tb_model_weights(model, global_step)
return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
end_time = time.time()
c_logger.print_eval_start()
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
m, y = format_data(data)
loader_time = time.time() - end_time
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
y_noisy, noise_scale = model.module.compute_noisy_x(y)
else:
y_noisy, noise_scale = model.compute_noisy_x(y)
# forward pass
y_hat = model(y_noisy, m, noise_scale)
# compute losses
loss = criterion(y_noisy, y_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss}
loss_dict = dict()
for key, value in loss_wavegrad_dict.items():
if isinstance(value, (int, float)):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
step_time = time.time() - start_time
epoch_time += step_time
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values['avg_' + key] = value
update_eval_values['avg_loader_time'] = loader_time
update_eval_values['avg_step_time'] = step_time
keep_avg.update_values(update_eval_values)
# print eval stats
if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# compute spectrograms
figures = plot_results(y_hat, y, ap, global_step, 'eval')
tb_logger.tb_eval_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
c.audio["sample_rate"])
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global train_data, eval_data
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
# setup audio processor
ap = AudioProcessor(**c.audio)
    # DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
# setup models
model = setup_generator(c)
# setup optimizers
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
# DISTRIBUTED
if c.apex_amp_level:
# pylint: disable=import-outside-toplevel
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
# schedulers
scheduler = None
if 'lr_scheduler' in c:
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
# setup criterion
criterion = torch.nn.L1Loss().cuda()
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint['model'])
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
print(" > Restoring LR Scheduler...")
scheduler.load_state_dict(checkpoint['scheduler'])
# NOTE: Not sure if necessary
scheduler.optimizer = optimizer
except RuntimeError:
            # restore only matching layers.
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model.load_state_dict(model_dict)
del model_dict
        # DISTRIBUTED
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
        # reset lr if not continuing training.
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
else:
args.restore_step = 0
if use_cuda:
model.cuda()
criterion.cuda()
    # DISTRIBUTED
if num_gpus > 1:
model = DDP(model)
num_params = count_parameters(model)
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
if 'best_loss' not in locals():
best_loss = float('inf')
global_step = args.restore_step
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp)
eval_avg_loss_dict = evaluate(model, criterion, ap,
global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss]
best_loss = save_best_model(target_loss,
best_loss,
model,
optimizer,
scheduler,
None,
None,
None,
global_step,
epoch,
OUT_PATH,
model_losses=eval_avg_loss_dict,
amp_state_dict=amp.state_dict() if amp else None)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--continue_path',
type=str,
help=
        'Training output folder to continue training. Use it to continue a previous training run. If set, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument('--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
    # DISTRIBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(
args.continue_path +
"/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
# DISTRIBUTED
if c.apex_amp_level:
print(" > apex AMP level: ", c.apex_amp_level)
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_config_file(args.config_path,
os.path.join(OUT_PATH, 'config.json'), new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()
sys.exit(1)
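
A rough single-step sketch (not part of the commit) of the objective wired up in train() above: diffuse a clean waveform with compute_noisy_x(), run the model on the noisy waveform plus the mel condition, and apply the same L1 pairing the loop uses. The Wavegrad class comes from the model file added later in this commit; its import path is an assumption.

import torch
from TTS.vocoder.models.wavegrad import Wavegrad  # import path assumed

model = Wavegrad()                     # default hyper-parameters, hop_length = 300
criterion = torch.nn.L1Loss()
mel = torch.randn(2, 80, 24)           # conditioning features (B, num_mels, frames)
audio = torch.randn(2, 1, 24 * 300)    # clean waveform segments (B, 1, frames * hop)

y_noisy, noise_scale = model.compute_noisy_x(audio)   # forward diffusion
y_hat = model(y_noisy, mel, noise_scale)              # denoising prediction
loss = criterion(y_noisy, y_hat)                      # same pairing as train()/evaluate()
loss.backward()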


@@ -0,0 +1,103 @@
{
"run_name": "wavegrad-libritts",
"run_description": "wavegrad libritts",
"audio":{
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
// Audio processing parameters
"sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"ref_level_db": 0, // reference level db, theoretically 20db is the sound of air.
// Silence trimming
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
// MelSpectrogram parameters
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 1.0, // scaler value appplied after log transform of spectrogram.
// Normalization parameters
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
"min_level_db": -100, // lower bound for normalization
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": "/home/erogol/Data/libritts/LibriTTS/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
// DISTRIBUTED TRAINING
"apex_amp_level": "O1", // amp optimization level. "O1" is currentl supported.
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"target_loss": "avg_wavegrad_loss", // loss value to pick the best model to save after each epoch
// MODEL PARAMETERS
"generator_model": "wavegrad",
"model_params":{
"x_conv_channels":32,
"c_conv_channels":768,
"ublock_out_channels": [768, 512, 512, 256, 128],
"dblock_out_channels": [128, 128, 256, 512],
"upsample_factors": [4, 4, 4, 2, 2],
"upsample_dilations": [
[1, 2, 1, 2],
[1, 2, 1, 2],
[1, 2, 4, 8],
[1, 2, 4, 8],
[1, 2, 4, 8]]
},
// DATASET
"data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // root data path. It finds all wav files recursively from there.
"feature_path": null, // if you use precomputed features
"seq_len": 6144, // 24 * hop_length
"pad_short": 2000, // additional padding for short wavs
"conv_pad": 0, // additional padding against convolutions applied to spectrograms
"use_noise_augment": false, // add noise to the audio signal for augmentation
"use_cache": true, // use in memory cache to keep the computed features. This might cause OOM.
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 64, // Batch size for training.
// VALIDATION
"run_eval": true, // enable/disable evaluation run
// OPTIMIZER
"epochs": 10000, // total number of epochs to train.
"clip_grad": 1, // Generator gradient clipping threshold. Apply gradient clipping if > 0
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"lr_scheduler_params": {
"gamma": 0.5,
"milestones": [100000, 200000, 300000, 400000, 500000, 600000]
},
"lr": 1e-4, // Initial learning rate. If Noam decay is active, maximum learning rate.
// TENSORBOARD and LOGGING
"print_step": 25, // Number of steps to log traning on console.
"print_eval": false, // If True, it prints loss values for each step in eval run.
"save_step": 10000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
// DATA LOADING
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"eval_split_size": 10,
// PATHS
"output_path": "/home/erogol/Models/LJSpeech/"
}
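
A quick sanity check (not part of the commit) of the segment settings above: "seq_len": 6144 is indeed 24 * hop_length, so each training segment covers seq_len // hop_length + 2 * conv_pad mel frames, matching the assertion in WaveGradDataset below.

seq_len, hop_length, conv_pad = 6144, 256, 0
assert seq_len % hop_length == 0
print(seq_len // hop_length + 2 * conv_pad)  # 24 conditioning frames per segment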

TTS/vocoder/datasets/wavegrad_dataset.py Normal file

@@ -0,0 +1,113 @@
import os
import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset
from multiprocessing import Manager
class WaveGradDataset(Dataset):
"""
    WaveGrad Dataset searches for all the wav files under the root path,
    converts them to acoustic features on the fly, and returns
    random segments of (audio, feature) couples.
"""
def __init__(self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False):
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
self.conv_pad = conv_pad
self.is_training = is_training
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# cache acoustic features
if use_cache:
self.create_feature_cache()
def create_feature_cache(self):
self.manager = Manager()
self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))]
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
def __len__(self):
return len(self.item_list)
def __getitem__(self, idx):
item = self.load_item(idx)
return item
def load_item(self, idx):
""" load (audio, feat) couple """
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
if len(audio) < self.seq_len + self.pad_short:
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
mode='constant', constant_values=0.0)
mel = self.ap.melspectrogram(audio)
else:
# load precomputed features
wavpath, feat_path = self.item_list[idx]
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = np.load(feat_path)
# correct the audio length wrt padding applied in stft
audio = np.pad(audio, (0, self.hop_len), mode="edge")
audio = audio[:mel.shape[-1] * self.hop_len]
assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'
audio = torch.from_numpy(audio).float().unsqueeze(0)
mel = torch.from_numpy(mel).float().squeeze(0)
if self.return_segments:
max_mel_start = mel.shape[1] - self.feat_frame_len
mel_start = random.randint(0, max_mel_start)
mel_end = mel_start + self.feat_frame_len
mel = mel[:, mel_start:mel_end]
audio_start = mel_start * self.hop_len
audio = audio[:, audio_start:audio_start +
self.seq_len]
if self.use_noise_augment and self.is_training and self.return_segments:
audio = audio + (1 / 32768) * torch.randn_like(audio)
return (mel, audio)
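
An illustrative check (not part of the commit) of the segmenting logic in load_item() above, using random tensors instead of a real (mel, audio) pair: a window of feat_frame_len mel frames is cut at a random position, and the matching audio window starts at mel_start * hop_len and spans seq_len samples.

import random
import torch

seq_len, hop_len, conv_pad, num_mels = 6144, 256, 0, 80
feat_frame_len = seq_len // hop_len + 2 * conv_pad
mel = torch.randn(num_mels, 100)        # stand-in for ap.melspectrogram(audio)
audio = torch.randn(1, 100 * hop_len)   # waveform of matching length

mel_start = random.randint(0, mel.shape[1] - feat_frame_len)
mel_seg = mel[:, mel_start:mel_start + feat_frame_len]
audio_seg = audio[:, mel_start * hop_len:mel_start * hop_len + seq_len]
print(mel_seg.shape, audio_seg.shape)   # (80, 24) and (1, 6144)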

TTS/vocoder/layers/wavegrad.py Normal file

@@ -0,0 +1,150 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
class NoiseLevelEncoding(nn.Module):
"""Noise level encoding applying same
encoding vector to all time steps. It is
different than the original implementation."""
def __init__(self, n_channels):
super().__init__()
self.n_channels = n_channels
self.length = n_channels // 2
assert n_channels % 2 == 0
enc = self.init_encoding(self.length)
self.register_buffer('enc', enc)
def forward(self, x, noise_level):
"""
Shapes:
x: B x C x T
noise_level: B
"""
return (x + self.encoding(noise_level)[:, :, None])
def init_encoding(self, length):
div_by = torch.arange(length) / length
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
return enc
def encoding(self, noise_level):
encoding = noise_level.unsqueeze(1) * self.enc
encoding = torch.cat(
[torch.sin(encoding), torch.cos(encoding)], dim=-1)
return encoding
class FiLM(nn.Module):
"""Feature-wise Linear Modulation. It combines information from
both noisy waveform and input mel-spectrogram. The FiLM module
produces both scale and bias vectors given inputs, which are
used in a UBlock for feature-wise affine transformation."""
def __init__(self, in_channels, out_channels):
super().__init__()
self.encoding = NoiseLevelEncoding(in_channels)
self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
self._init_parameters()
def _init_parameters(self):
nn.init.orthogonal_(self.conv_in.weight)
nn.init.orthogonal_(self.conv_out.weight)
def forward(self, x, noise_scale):
x = self.conv_in(x)
x = F.leaky_relu(x, 0.2)
x = self.encoding(x, noise_scale)
shift, scale = torch.chunk(self.conv_out(x), 2, dim=1)
return shift, scale
@torch.jit.script
def shif_and_scale(x, scale, shift):
o = shift + scale * x
return o
class UBlock(nn.Module):
def __init__(self, in_channels, hid_channels, upsample_factor, dilations):
super().__init__()
assert len(dilations) == 4
self.upsample_factor = upsample_factor
self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
self.main_block1 = nn.ModuleList([
nn.Conv1d(in_channels,
hid_channels,
3,
dilation=dilations[0],
padding=dilations[0]),
nn.Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[1],
padding=dilations[1])
])
self.main_block2 = nn.ModuleList([
nn.Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[2],
padding=dilations[2]),
nn.Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[3],
padding=dilations[3])
])
def forward(self, x, shift, scale):
upsample_size = x.shape[-1] * self.upsample_factor
x = F.interpolate(x, size=upsample_size)
res = self.shortcut_conv(x)
o = F.leaky_relu(x, 0.2)
o = self.main_block1[0](o)
o = shif_and_scale(o, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.main_block1[1](o)
o = o + res
res = o
o = shif_and_scale(o, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.main_block2[0](o)
o = shif_and_scale(o, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.main_block2[1](o)
o = o + res
return o
class DBlock(nn.Module):
def __init__(self, in_channels, hid_channels, downsample_factor):
super().__init__()
self.downsample_factor = downsample_factor
self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
self.main_convs = nn.ModuleList([
nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
])
def forward(self, x):
size = x.shape[-1] // self.downsample_factor
res = self.res_conv(x)
res = F.interpolate(res, size=size)
o = F.interpolate(x, size=size)
for layer in self.main_convs:
o = F.leaky_relu(o, 0.2)
o = layer(o)
return o + res
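
A hypothetical smoke test (not part of the commit) wiring one DBlock/FiLM pair to a matching UBlock: FiLM maps its input plus a per-item noise level to (shift, scale), and the UBlock applies them while upsampling by the same factor the DBlock used for downsampling, so the time axes line up again. The import path is inferred from the relative import in the model file below.

import torch
from TTS.vocoder.layers.wavegrad import DBlock, FiLM, UBlock

B, T = 2, 64
x = torch.randn(B, 32, T)                    # stand-in for the first waveform conv output
noise_scale = torch.rand(B)                  # one sampled noise level per batch item

film = FiLM(32, 128)                         # -> (shift, scale), each (B, 128, T)
dblock = DBlock(32, 128, 2)                  # 32 -> 128 channels, T -> T // 2
ublock = UBlock(128, 128, 2, [1, 2, 1, 2])   # 128 -> 128 channels, T // 2 -> T

shift, scale = film(x, noise_scale)
down = dblock(x)                             # (B, 128, T // 2)
up = ublock(down, shift, scale)              # (B, 128, T)
print(shift.shape, down.shape, up.shape)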

TTS/vocoder/models/wavegrad.py Normal file

@@ -0,0 +1,131 @@
import numpy as np
import torch
from torch import nn
from ..layers.wavegrad import DBlock, FiLM, UBlock
class Wavegrad(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(self,
in_channels=80,
out_channels=1,
x_conv_channels=32,
c_conv_channels=768,
dblock_out_channels=[128, 128, 256, 512],
ublock_out_channels=[512, 512, 256, 128, 128],
upsample_factors=[5, 5, 3, 2, 2],
upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
[1, 2, 4, 8], [1, 2, 4, 8]]):
super().__init__()
assert len(upsample_factors) == len(upsample_dilations)
assert len(upsample_factors) == len(ublock_out_channels)
# inference time noise schedule params
self.S = 1000
beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
self.register_buffer('beta', beta)
self.register_buffer('alpha', alpha)
self.register_buffer('alpha_cum', alpha_cum)
self.register_buffer('noise_level', noise_level)
# setup up-down sampling parameters
self.hop_length = np.prod(upsample_factors)
self.upsample_factors = upsample_factors
self.downsample_factors = upsample_factors[::-1][:-1]
### define DBlocks, FiLM layers ###
self.dblocks = nn.ModuleList([
nn.Conv1d(out_channels, x_conv_channels, 5, padding=2),
])
ic = x_conv_channels
self.films = nn.ModuleList([])
for oc, df in zip(dblock_out_channels, self.downsample_factors):
# print('dblock(', ic, ', ', oc, ', ', df, ")")
layer = DBlock(ic, oc, df)
self.dblocks.append(layer)
# print('film(', ic, ', ', oc,")")
layer = FiLM(ic, oc)
self.films.append(layer)
ic = oc
# last FiLM block
# print('film(', ic, ', ', dblock_out_channels[-1],")")
self.films.append(FiLM(ic, dblock_out_channels[-1]))
### define UBlocks ###
self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1)
self.ublocks = nn.ModuleList([])
ic = c_conv_channels
for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)):
# print('ublock(', ic, ', ', oc, ', ', uf, ")")
layer = UBlock(ic, oc, uf, upsample_dilations[idx])
self.ublocks.append(layer)
ic = oc
# define last layer
# print(ic, 'last_conv--', out_channels)
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
def _setup_noise_level(self, noise_schedule=None):
"""compute noise schedule parameters"""
if noise_schedule is None:
beta = np.linspace(1e-6, 0.01, self.S)
else:
beta = noise_schedule
alpha = 1 - beta
alpha_cum = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
        beta = torch.from_numpy(beta.astype(np.float32))
        alpha = torch.from_numpy(alpha.astype(np.float32))
        alpha_cum = torch.from_numpy(alpha_cum.astype(np.float32))
noise_level = torch.from_numpy(noise_level.astype(np.float32))
return beta, alpha, alpha_cum, noise_level
def compute_noisy_x(self, x):
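        # Forward diffusion: pick a random schedule step s for each batch item,
        # draw a noise level uniformly between noise_level[s-1] and noise_level[s],
        # and mix the clean waveform with Gaussian noise at that level. Returns the
        # noisy waveform (B, 1, T) and the sampled noise level (B,).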
B = x.shape[0]
if len(x.shape) == 3:
x = x.squeeze(1)
s = torch.randint(1, self.S + 1, [B]).to(x).long()
l_a, l_b = self.noise_level[s-1], self.noise_level[s]
noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(x)
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
return noisy_x.unsqueeze(1), noise_scale[:, 0]
def forward(self, x, c, noise_scale):
assert len(c.shape) == 3 # B, C, T
assert len(x.shape) == 3 # B, 1, T
o = x
shift_and_scales = []
for film, dblock in zip(self.films, self.dblocks):
o = dblock(o)
shift_and_scales.append(film(o, noise_scale))
o = self.c_conv(c)
for ublock, (film_shift, film_scale) in zip(self.ublocks,
reversed(shift_and_scales)):
o = ublock(o, film_shift, film_scale)
o = self.last_conv(o)
return o
def inference(self, c):
with torch.no_grad():
x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
            noise_scale = (self.alpha_cum**0.5).float().unsqueeze(1).to(c)
for n in range(len(self.alpha) - 1, -1, -1):
c1 = 1 / self.alpha[n]**0.5
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
                x = c1 * (x - c2 * self.forward(
                    x.unsqueeze(1), c, noise_scale[n]).squeeze(1))
if n > 0:
noise = torch.randn_like(x)
sigma = ((1.0 - self.alpha_cum[n - 1]) /
(1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
x += sigma * noise
x = torch.clamp(x, -1.0, 1.0)
return x
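
A hypothetical inference sketch (not part of the commit): with the default hyper-parameters hop_length = 5*5*3*2*2 = 300, so a mel of T frames maps to a waveform of 300 * T samples. inference() runs all S=1000 refinement steps, so it is slow on CPU; the import path is an assumption.

import torch
from TTS.vocoder.models.wavegrad import Wavegrad  # import path assumed

model = Wavegrad().eval()
mel = torch.randn(1, 80, 24)    # (B, num_mels, frames)
wav = model.inference(mel)      # (B, 300 * 24) samples, clamped to [-1, 1]
print(wav.shape)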