mirror of https://github.com/coqui-ai/TTS.git
wavegrad updates
This commit is contained in:
parent
a1582a0e12
commit
7bcdb7ac35
|
@ -50,22 +50,35 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=c.batch_size,
|
||||
shuffle=False if num_gpus > 1 else True,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=c.num_val_loader_workers
|
||||
if is_val else c.num_loader_workers,
|
||||
pin_memory=False)
|
||||
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def format_data(data):
|
||||
# return a whole audio segment
|
||||
m, y = data
|
||||
m, x = data
|
||||
if use_cuda:
|
||||
m = m.cuda(non_blocking=True)
|
||||
y = y.cuda(non_blocking=True)
|
||||
return m, y
|
||||
x = x.cuda(non_blocking=True)
|
||||
return m, x
|
||||
|
||||
|
||||
def format_test_data(data):
    """Prepare one (mel, audio) test sample for inference.

    Both tensors get a new leading dimension via ``unsqueeze(0)`` and are
    moved to the GPU when the module-level ``use_cuda`` flag is set.
    """
    mel, wav = data
    mel, wav = mel.unsqueeze(0), wav.unsqueeze(0)
    if not use_cuda:
        return mel, wav
    mel = mel.cuda(non_blocking=True)
    wav = wav.cuda(non_blocking=True)
    return mel, wav
|
||||
|
||||
|
||||
def train(model, criterion, optimizer,
|
||||
|
@ -81,26 +94,36 @@ def train(model, criterion, optimizer,
|
|||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
c_logger.print_train_start()
|
||||
# setup noise schedule
|
||||
noise_schedule = c['train_noise_schedule']
|
||||
if hasattr(model, 'module'):
|
||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
else:
|
||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
m, y = format_data(data)
|
||||
m, x = format_data(data)
|
||||
loader_time = time.time() - end_time
|
||||
|
||||
global_step += 1
|
||||
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||
else:
|
||||
y_noisy, noise_scale = model.compute_noisy_x(y)
|
||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||
|
||||
# forward pass
|
||||
y_hat = model(y_noisy, m, noise_scale)
|
||||
noise_hat = model(x_noisy, m, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(y_noisy, y_hat)
|
||||
loss = criterion(noise, noise_hat)
|
||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||
|
||||
# backward pass with loss scaling
|
||||
|
@ -181,15 +204,6 @@ def train(model, criterion, optimizer,
|
|||
OUT_PATH,
|
||||
model_losses=loss_dict)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'train/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
# print epoch stats
|
||||
|
@ -218,23 +232,23 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
m, y = format_data(data)
|
||||
m, x = format_data(data)
|
||||
loader_time = time.time() - end_time
|
||||
|
||||
global_step += 1
|
||||
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
||||
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||
else:
|
||||
y_noisy, noise_scale = model.compute_noisy_x(y)
|
||||
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||
|
||||
|
||||
# forward pass
|
||||
y_hat = model(y_noisy, m, noise_scale)
|
||||
noise_hat = model(x_noisy, m, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(y_noisy, y_hat)
|
||||
loss = criterion(noise, noise_hat)
|
||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||
|
||||
|
||||
|
@ -261,14 +275,32 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
samples = data_loader.dataset.load_test_samples(1)
|
||||
m, x = format_test_data(samples[0])
|
||||
|
||||
# setup noise schedule and inference
|
||||
noise_schedule = c['test_noise_schedule']
|
||||
if hasattr(model, 'module'):
|
||||
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
# compute voice
|
||||
x_pred = model.module.inference(m)
|
||||
else:
|
||||
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
# compute voice
|
||||
x_pred = model.inference(m)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat, y, ap, global_step, 'eval')
|
||||
figures = plot_results(x_pred, x, ap, global_step, 'eval')
|
||||
tb_logger.tb_eval_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
c.audio["sample_rate"])
|
||||
|
||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||
|
||||
|
|
|
@ -92,8 +92,8 @@
|
|||
// DATASET
|
||||
"data_path": "/home/erogol/Data/MozillaMerged22050/wavs/",
|
||||
"feature_path": null,
|
||||
"seq_len": 16384,
|
||||
"pad_short": 2000,
|
||||
"seq_len": 6144,
|
||||
"pad_short": 500,
|
||||
"conv_pad": 0,
|
||||
"use_noise_augment": false,
|
||||
"use_cache": true,
|
||||
|
@ -102,6 +102,16 @@
|
|||
|
||||
// TRAINING
|
||||
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"train_noise_schedule":{
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 1000
|
||||
},
|
||||
"test_noise_schedule":{
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 50
|
||||
},
|
||||
|
||||
// VALIDATION
|
||||
"run_eval": true,
|
||||
|
|
|
@ -71,6 +71,16 @@
|
|||
|
||||
// TRAINING
|
||||
"batch_size": 64, // Batch size for training.
|
||||
"train_noise_schedule":{
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 1000
|
||||
},
|
||||
"test_noise_schedule":{
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 50
|
||||
},
|
||||
|
||||
// VALIDATION
|
||||
"run_eval": true, // enable/disable evaluation run
|
||||
|
|
|
@ -62,6 +62,13 @@ class WaveGradDataset(Dataset):
|
|||
item = self.load_item(idx)
|
||||
return item
|
||||
|
||||
def load_test_samples(self, num_samples):
    """Return the first ``num_samples`` items as ``[mel, audio]`` lists.

    Items are fetched through ``self.load_item`` for the indices
    ``0 .. num_samples - 1``, in order.
    """
    first_items = map(self.load_item, range(num_samples))
    return [[mel, audio] for mel, audio in first_items]
|
||||
|
||||
def load_item(self, idx):
|
||||
""" load (audio, feat) couple """
|
||||
if self.compute_feat:
|
||||
|
|
|
@ -4,6 +4,16 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
    """1D convolution whose kernel is orthogonally initialized.

    Takes the same positional and keyword arguments as ``torch.nn.Conv1d``.
    The bias, when the layer has one, starts at zero.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonally initialize the weight and zero the bias, if any."""
        nn.init.orthogonal_(self.weight)
        # ``self.bias`` is None when the layer is created with ``bias=False``;
        # zeroing it unconditionally would raise on construction.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
||||
|
||||
|
||||
class NoiseLevelEncoding(nn.Module):
|
||||
"""Noise level encoding applying same
|
||||
encoding vector to all time steps. It is
|
||||
|
@ -25,7 +35,8 @@ class NoiseLevelEncoding(nn.Module):
|
|||
"""
|
||||
return (x + self.encoding(noise_level)[:, :, None])
|
||||
|
||||
def init_encoding(self, length):
|
||||
@staticmethod
|
||||
def init_encoding(length):
|
||||
div_by = torch.arange(length) / length
|
||||
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
|
||||
return enc
|
||||
|
@ -46,8 +57,8 @@ class FiLM(nn.Module):
|
|||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.encoding = NoiseLevelEncoding(in_channels)
|
||||
self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
|
||||
self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
|
||||
self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1)
|
||||
self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1)
|
||||
self._init_parameters()
|
||||
|
||||
def _init_parameters(self):
|
||||
|
@ -74,26 +85,26 @@ class UBlock(nn.Module):
|
|||
assert len(dilations) == 4
|
||||
|
||||
self.upsample_factor = upsample_factor
|
||||
self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
||||
self.shortcut_conv = Conv1d(in_channels, hid_channels, 1)
|
||||
self.main_block1 = nn.ModuleList([
|
||||
nn.Conv1d(in_channels,
|
||||
Conv1d(in_channels,
|
||||
hid_channels,
|
||||
3,
|
||||
dilation=dilations[0],
|
||||
padding=dilations[0]),
|
||||
nn.Conv1d(hid_channels,
|
||||
Conv1d(hid_channels,
|
||||
hid_channels,
|
||||
3,
|
||||
dilation=dilations[1],
|
||||
padding=dilations[1])
|
||||
])
|
||||
self.main_block2 = nn.ModuleList([
|
||||
nn.Conv1d(hid_channels,
|
||||
Conv1d(hid_channels,
|
||||
hid_channels,
|
||||
3,
|
||||
dilation=dilations[2],
|
||||
padding=dilations[2]),
|
||||
nn.Conv1d(hid_channels,
|
||||
Conv1d(hid_channels,
|
||||
hid_channels,
|
||||
3,
|
||||
dilation=dilations[3],
|
||||
|
@ -129,11 +140,11 @@ class DBlock(nn.Module):
|
|||
def __init__(self, in_channels, hid_channels, downsample_factor):
|
||||
super().__init__()
|
||||
self.downsample_factor = downsample_factor
|
||||
self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
||||
self.res_conv = Conv1d(in_channels, hid_channels, 1)
|
||||
self.main_convs = nn.ModuleList([
|
||||
nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
|
||||
nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
|
||||
nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
|
||||
Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
|
||||
Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
|
||||
Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -22,14 +22,6 @@ class Wavegrad(nn.Module):
|
|||
assert len(upsample_factors) == len(upsample_dilations)
|
||||
assert len(upsample_factors) == len(ublock_out_channels)
|
||||
|
||||
# inference time noise schedule params
|
||||
self.S = 1000
|
||||
beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
|
||||
self.register_buffer('beta', beta)
|
||||
self.register_buffer('alpha', alpha)
|
||||
self.register_buffer('alpha_cum', alpha_cum)
|
||||
self.register_buffer('noise_level', noise_level)
|
||||
|
||||
# setup up-down sampling parameters
|
||||
self.hop_length = np.prod(upsample_factors)
|
||||
self.upsample_factors = upsample_factors
|
||||
|
@ -68,21 +60,23 @@ class Wavegrad(nn.Module):
|
|||
# print(ic, 'last_conv--', out_channels)
|
||||
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
|
||||
|
||||
def _setup_noise_level(self, noise_schedule=None):
|
||||
"""compute noise schedule parameters"""
|
||||
if noise_schedule is None:
|
||||
beta = np.linspace(1e-6, 0.01, self.S)
|
||||
else:
|
||||
beta = noise_schedule
|
||||
alpha = 1 - beta
|
||||
alpha_cum = np.cumprod(alpha)
|
||||
noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
|
||||
# inference time noise schedule params
|
||||
self.S = 1000
|
||||
self.init_noise_schedule(self.S)
|
||||
|
||||
beta = torch.from_numpy(beta)
|
||||
alpha = torch.from_numpy(alpha)
|
||||
alpha_cum = torch.from_numpy(alpha_cum)
|
||||
noise_level = torch.from_numpy(noise_level.astype(np.float32))
|
||||
return beta, alpha, alpha_cum, noise_level
|
||||
|
||||
def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01):
    """Build a linear noise schedule and store it as module buffers.

    ``beta`` holds ``num_iter`` evenly spaced values from ``min_val`` to
    ``max_val``; ``alpha`` is ``1 - beta``; ``alpha_cum`` is the cumulative
    product of ``alpha``; ``noise_level`` is ``sqrt(alpha_cum)`` with a
    leading 1. All tensors are placed on the device of ``last_conv.weight``.
    """
    target = self.last_conv.weight.device
    betas = torch.linspace(min_val, max_val, num_iter).to(target)
    alphas = 1 - betas
    alphas_cum = alphas.cumprod(dim=0)
    leading_one = torch.FloatTensor([1]).to(target)
    levels = torch.cat([leading_one, alphas_cum ** 0.5])

    schedule = (
        ('beta', betas),
        ('alpha', alphas),
        ('alpha_cum', alphas_cum),
        ('noise_level', levels),
    )
    for buffer_name, values in schedule:
        self.register_buffer(buffer_name, values)
|
||||
|
||||
def compute_noisy_x(self, x):
|
||||
B = x.shape[0]
|
||||
|
@ -94,7 +88,7 @@ class Wavegrad(nn.Module):
|
|||
noise_scale = noise_scale.unsqueeze(1)
|
||||
noise = torch.randn_like(x)
|
||||
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
|
||||
return noisy_x.unsqueeze(1), noise_scale[:, 0]
|
||||
return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]
|
||||
|
||||
def forward(self, x, c, noise_scale):
|
||||
assert len(c.shape) == 3 # B, C, T
|
||||
|
@ -114,9 +108,8 @@ class Wavegrad(nn.Module):
|
|||
|
||||
def inference(self, c):
|
||||
with torch.no_grad():
|
||||
x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
|
||||
noise_scale = torch.from_numpy(
|
||||
self.alpha_cum**0.5).float().unsqueeze(1).to(c)
|
||||
x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
|
||||
noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
|
||||
for n in range(len(self.alpha) - 1, -1, -1):
|
||||
c1 = 1 / self.alpha[n]**0.5
|
||||
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
|
||||
|
|
Loading…
Reference in New Issue