mirror of https://github.com/coqui-ai/TTS.git
wavegrad updates
This commit is contained in:
parent
a1582a0e12
commit
7bcdb7ac35
|
@ -50,22 +50,35 @@ def setup_loader(ap, is_val=False, verbose=False):
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
loader = DataLoader(dataset,
|
loader = DataLoader(dataset,
|
||||||
batch_size=c.batch_size,
|
batch_size=c.batch_size,
|
||||||
shuffle=False if num_gpus > 1 else True,
|
shuffle=num_gpus <= 1,
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=c.num_val_loader_workers
|
num_workers=c.num_val_loader_workers
|
||||||
if is_val else c.num_loader_workers,
|
if is_val else c.num_loader_workers,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
|
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
|
||||||
def format_data(data):
|
def format_data(data):
|
||||||
# return a whole audio segment
|
# return a whole audio segment
|
||||||
m, y = data
|
m, x = data
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
m = m.cuda(non_blocking=True)
|
m = m.cuda(non_blocking=True)
|
||||||
y = y.cuda(non_blocking=True)
|
x = x.cuda(non_blocking=True)
|
||||||
return m, y
|
return m, x
|
||||||
|
|
||||||
|
|
||||||
|
def format_test_data(data):
    """Prepare a single ``(m, x)`` test sample for inference.

    Each element of ``data`` receives a new leading dimension through
    ``unsqueeze(0)``. When the module-level ``use_cuda`` flag is truthy,
    both results are additionally transferred with
    ``.cuda(non_blocking=True)``.

    Args:
        data: a two-element iterable ``(m, x)``. Presumably the conditioning
            features and the matching waveform of one dataset item --
            TODO confirm against the dataset that produces it.

    Returns:
        tuple: the transformed ``(m, x)`` pair.
    """
    mel, audio = data
    # add the leading dimension first, exactly as a batch of size one
    batched = (mel.unsqueeze(0), audio.unsqueeze(0))
    if not use_cuda:
        return batched
    return tuple(tensor.cuda(non_blocking=True) for tensor in batched)
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, optimizer,
|
def train(model, criterion, optimizer,
|
||||||
|
@ -81,26 +94,36 @@ def train(model, criterion, optimizer,
|
||||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
c_logger.print_train_start()
|
c_logger.print_train_start()
|
||||||
|
# setup noise schedule
|
||||||
|
noise_schedule = c['train_noise_schedule']
|
||||||
|
if hasattr(model, 'module'):
|
||||||
|
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||||
|
noise_schedule['min_val'],
|
||||||
|
noise_schedule['max_val'])
|
||||||
|
else:
|
||||||
|
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||||
|
noise_schedule['min_val'],
|
||||||
|
noise_schedule['max_val'])
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# format data
|
# format data
|
||||||
m, y = format_data(data)
|
m, x = format_data(data)
|
||||||
loader_time = time.time() - end_time
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# compute noisy input
|
# compute noisy input
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||||
else:
|
else:
|
||||||
y_noisy, noise_scale = model.compute_noisy_x(y)
|
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
y_hat = model(y_noisy, m, noise_scale)
|
noise_hat = model(x_noisy, m, noise_scale)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
loss = criterion(y_noisy, y_hat)
|
loss = criterion(noise, noise_hat)
|
||||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||||
|
|
||||||
# backward pass with loss scaling
|
# backward pass with loss scaling
|
||||||
|
@ -181,15 +204,6 @@ def train(model, criterion, optimizer,
|
||||||
OUT_PATH,
|
OUT_PATH,
|
||||||
model_losses=loss_dict)
|
model_losses=loss_dict)
|
||||||
|
|
||||||
# compute spectrograms
|
|
||||||
figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
|
|
||||||
tb_logger.tb_train_figures(global_step, figures)
|
|
||||||
|
|
||||||
# Sample audio
|
|
||||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
|
||||||
tb_logger.tb_train_audios(global_step,
|
|
||||||
{'train/audio': sample_voice},
|
|
||||||
c.audio["sample_rate"])
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
|
@ -218,23 +232,23 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# format data
|
# format data
|
||||||
m, y = format_data(data)
|
m, x = format_data(data)
|
||||||
loader_time = time.time() - end_time
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# compute noisy input
|
# compute noisy input
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
|
||||||
else:
|
else:
|
||||||
y_noisy, noise_scale = model.compute_noisy_x(y)
|
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
|
||||||
|
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
y_hat = model(y_noisy, m, noise_scale)
|
noise_hat = model(x_noisy, m, noise_scale)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
loss = criterion(y_noisy, y_hat)
|
loss = criterion(noise, noise_hat)
|
||||||
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||||
|
|
||||||
|
|
||||||
|
@ -261,12 +275,30 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
|
samples = data_loader.dataset.load_test_samples(1)
|
||||||
|
m, x = format_test_data(samples[0])
|
||||||
|
|
||||||
|
# setup noise schedule and inference
|
||||||
|
noise_schedule = c['test_noise_schedule']
|
||||||
|
if hasattr(model, 'module'):
|
||||||
|
model.module.init_noise_schedule(noise_schedule['num_steps'],
|
||||||
|
noise_schedule['min_val'],
|
||||||
|
noise_schedule['max_val'])
|
||||||
|
# compute voice
|
||||||
|
x_pred = model.module.inference(m)
|
||||||
|
else:
|
||||||
|
model.init_noise_schedule(noise_schedule['num_steps'],
|
||||||
|
noise_schedule['min_val'],
|
||||||
|
noise_schedule['max_val'])
|
||||||
|
# compute voice
|
||||||
|
x_pred = model.inference(m)
|
||||||
|
|
||||||
# compute spectrograms
|
# compute spectrograms
|
||||||
figures = plot_results(y_hat, y, ap, global_step, 'eval')
|
figures = plot_results(x_pred, x, ap, global_step, 'eval')
|
||||||
tb_logger.tb_eval_figures(global_step, figures)
|
tb_logger.tb_eval_figures(global_step, figures)
|
||||||
|
|
||||||
# Sample audio
|
# Sample audio
|
||||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
|
||||||
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
||||||
c.audio["sample_rate"])
|
c.audio["sample_rate"])
|
||||||
|
|
||||||
|
|
|
@ -92,8 +92,8 @@
|
||||||
// DATASET
|
// DATASET
|
||||||
"data_path": "/home/erogol/Data/MozillaMerged22050/wavs/",
|
"data_path": "/home/erogol/Data/MozillaMerged22050/wavs/",
|
||||||
"feature_path": null,
|
"feature_path": null,
|
||||||
"seq_len": 16384,
|
"seq_len": 6144,
|
||||||
"pad_short": 2000,
|
"pad_short": 500,
|
||||||
"conv_pad": 0,
|
"conv_pad": 0,
|
||||||
"use_noise_augment": false,
|
"use_noise_augment": false,
|
||||||
"use_cache": true,
|
"use_cache": true,
|
||||||
|
@ -102,6 +102,16 @@
|
||||||
|
|
||||||
// TRAINING
|
// TRAINING
|
||||||
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||||
|
"train_noise_schedule":{
|
||||||
|
"min_val": 1e-6,
|
||||||
|
"max_val": 1e-2,
|
||||||
|
"num_steps": 1000
|
||||||
|
},
|
||||||
|
"test_noise_schedule":{
|
||||||
|
"min_val": 1e-6,
|
||||||
|
"max_val": 1e-2,
|
||||||
|
"num_steps": 50
|
||||||
|
},
|
||||||
|
|
||||||
// VALIDATION
|
// VALIDATION
|
||||||
"run_eval": true,
|
"run_eval": true,
|
||||||
|
|
|
@ -71,6 +71,16 @@
|
||||||
|
|
||||||
// TRAINING
|
// TRAINING
|
||||||
"batch_size": 64, // Batch size for training.
|
"batch_size": 64, // Batch size for training.
|
||||||
|
"train_noise_schedule":{
|
||||||
|
"min_val": 1e-6,
|
||||||
|
"max_val": 1e-2,
|
||||||
|
"num_steps": 1000
|
||||||
|
},
|
||||||
|
"test_noise_schedule":{
|
||||||
|
"min_val": 1e-6,
|
||||||
|
"max_val": 1e-2,
|
||||||
|
"num_steps": 50
|
||||||
|
},
|
||||||
|
|
||||||
// VALIDATION
|
// VALIDATION
|
||||||
"run_eval": true, // enable/disable evaluation run
|
"run_eval": true, // enable/disable evaluation run
|
||||||
|
|
|
@ -62,6 +62,13 @@ class WaveGradDataset(Dataset):
|
||||||
item = self.load_item(idx)
|
item = self.load_item(idx)
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
def load_test_samples(self, num_samples):
    """Return the first ``num_samples`` items as two-element lists.

    Items are fetched with ``self.load_item`` for the indices
    ``0 .. num_samples - 1``, in that order, and each result is unpacked
    into exactly two values that are stored as a ``[first, second]`` list.

    NOTE(review): the caller-facing meaning is presumably ``[mel, audio]``;
    confirm the order returned by ``load_item``.

    Args:
        num_samples: how many leading items to load.

    Returns:
        list: one ``[first, second]`` list per loaded item.
    """
    def as_pair(index):
        # unpacking enforces that load_item yields exactly two values
        first, second = self.load_item(index)
        return [first, second]

    return [as_pair(index) for index in range(num_samples)]
|
||||||
|
|
||||||
def load_item(self, idx):
|
def load_item(self, idx):
|
||||||
""" load (audio, feat) couple """
|
""" load (audio, feat) couple """
|
||||||
if self.compute_feat:
|
if self.compute_feat:
|
||||||
|
|
|
@ -4,6 +4,16 @@ from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` whose weight is orthogonally initialised and whose
    bias (when present) starts at zero.

    All constructor arguments are forwarded unchanged to ``nn.Conv1d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): the parent constructor appears to call
        # reset_parameters() itself; the explicit call is kept so the
        # initialisation sequence stays exactly as before -- confirm.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the layer: orthogonal weight, zero bias."""
        nn.init.orthogonal_(self.weight)
        # With ``bias=False`` the layer has no bias tensor (``self.bias`` is
        # None); zeroing it unconditionally would raise.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
||||||
|
|
||||||
|
|
||||||
class NoiseLevelEncoding(nn.Module):
|
class NoiseLevelEncoding(nn.Module):
|
||||||
"""Noise level encoding applying same
|
"""Noise level encoding applying same
|
||||||
encoding vector to all time steps. It is
|
encoding vector to all time steps. It is
|
||||||
|
@ -25,7 +35,8 @@ class NoiseLevelEncoding(nn.Module):
|
||||||
"""
|
"""
|
||||||
return (x + self.encoding(noise_level)[:, :, None])
|
return (x + self.encoding(noise_level)[:, :, None])
|
||||||
|
|
||||||
def init_encoding(self, length):
|
@staticmethod
|
||||||
|
def init_encoding(length):
|
||||||
div_by = torch.arange(length) / length
|
div_by = torch.arange(length) / length
|
||||||
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
|
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
|
||||||
return enc
|
return enc
|
||||||
|
@ -46,8 +57,8 @@ class FiLM(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoding = NoiseLevelEncoding(in_channels)
|
self.encoding = NoiseLevelEncoding(in_channels)
|
||||||
self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
|
self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1)
|
||||||
self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
|
self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1)
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
def _init_parameters(self):
|
def _init_parameters(self):
|
||||||
|
@ -74,26 +85,26 @@ class UBlock(nn.Module):
|
||||||
assert len(dilations) == 4
|
assert len(dilations) == 4
|
||||||
|
|
||||||
self.upsample_factor = upsample_factor
|
self.upsample_factor = upsample_factor
|
||||||
self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
self.shortcut_conv = Conv1d(in_channels, hid_channels, 1)
|
||||||
self.main_block1 = nn.ModuleList([
|
self.main_block1 = nn.ModuleList([
|
||||||
nn.Conv1d(in_channels,
|
Conv1d(in_channels,
|
||||||
hid_channels,
|
hid_channels,
|
||||||
3,
|
3,
|
||||||
dilation=dilations[0],
|
dilation=dilations[0],
|
||||||
padding=dilations[0]),
|
padding=dilations[0]),
|
||||||
nn.Conv1d(hid_channels,
|
Conv1d(hid_channels,
|
||||||
hid_channels,
|
hid_channels,
|
||||||
3,
|
3,
|
||||||
dilation=dilations[1],
|
dilation=dilations[1],
|
||||||
padding=dilations[1])
|
padding=dilations[1])
|
||||||
])
|
])
|
||||||
self.main_block2 = nn.ModuleList([
|
self.main_block2 = nn.ModuleList([
|
||||||
nn.Conv1d(hid_channels,
|
Conv1d(hid_channels,
|
||||||
hid_channels,
|
hid_channels,
|
||||||
3,
|
3,
|
||||||
dilation=dilations[2],
|
dilation=dilations[2],
|
||||||
padding=dilations[2]),
|
padding=dilations[2]),
|
||||||
nn.Conv1d(hid_channels,
|
Conv1d(hid_channels,
|
||||||
hid_channels,
|
hid_channels,
|
||||||
3,
|
3,
|
||||||
dilation=dilations[3],
|
dilation=dilations[3],
|
||||||
|
@ -129,11 +140,11 @@ class DBlock(nn.Module):
|
||||||
def __init__(self, in_channels, hid_channels, downsample_factor):
|
def __init__(self, in_channels, hid_channels, downsample_factor):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.downsample_factor = downsample_factor
|
self.downsample_factor = downsample_factor
|
||||||
self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
self.res_conv = Conv1d(in_channels, hid_channels, 1)
|
||||||
self.main_convs = nn.ModuleList([
|
self.main_convs = nn.ModuleList([
|
||||||
nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
|
Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
|
||||||
nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
|
Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
|
||||||
nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
|
Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -22,14 +22,6 @@ class Wavegrad(nn.Module):
|
||||||
assert len(upsample_factors) == len(upsample_dilations)
|
assert len(upsample_factors) == len(upsample_dilations)
|
||||||
assert len(upsample_factors) == len(ublock_out_channels)
|
assert len(upsample_factors) == len(ublock_out_channels)
|
||||||
|
|
||||||
# inference time noise schedule params
|
|
||||||
self.S = 1000
|
|
||||||
beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
|
|
||||||
self.register_buffer('beta', beta)
|
|
||||||
self.register_buffer('alpha', alpha)
|
|
||||||
self.register_buffer('alpha_cum', alpha_cum)
|
|
||||||
self.register_buffer('noise_level', noise_level)
|
|
||||||
|
|
||||||
# setup up-down sampling parameters
|
# setup up-down sampling parameters
|
||||||
self.hop_length = np.prod(upsample_factors)
|
self.hop_length = np.prod(upsample_factors)
|
||||||
self.upsample_factors = upsample_factors
|
self.upsample_factors = upsample_factors
|
||||||
|
@ -68,21 +60,23 @@ class Wavegrad(nn.Module):
|
||||||
# print(ic, 'last_conv--', out_channels)
|
# print(ic, 'last_conv--', out_channels)
|
||||||
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
|
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
|
||||||
|
|
||||||
def _setup_noise_level(self, noise_schedule=None):
|
# inference time noise schedule params
|
||||||
"""compute noise schedule parameters"""
|
self.S = 1000
|
||||||
if noise_schedule is None:
|
self.init_noise_schedule(self.S)
|
||||||
beta = np.linspace(1e-6, 0.01, self.S)
|
|
||||||
else:
|
|
||||||
beta = noise_schedule
|
|
||||||
alpha = 1 - beta
|
|
||||||
alpha_cum = np.cumprod(alpha)
|
|
||||||
noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
|
|
||||||
|
|
||||||
beta = torch.from_numpy(beta)
|
|
||||||
alpha = torch.from_numpy(alpha)
|
def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01):
|
||||||
alpha_cum = torch.from_numpy(alpha_cum)
|
"""compute noise schedule parameters"""
|
||||||
noise_level = torch.from_numpy(noise_level.astype(np.float32))
|
device = self.last_conv.weight.device
|
||||||
return beta, alpha, alpha_cum, noise_level
|
beta = torch.linspace(min_val, max_val, num_iter).to(device)
|
||||||
|
alpha = 1 - beta
|
||||||
|
alpha_cum = alpha.cumprod(dim=0)
|
||||||
|
noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5])
|
||||||
|
|
||||||
|
self.register_buffer('beta', beta)
|
||||||
|
self.register_buffer('alpha', alpha)
|
||||||
|
self.register_buffer('alpha_cum', alpha_cum)
|
||||||
|
self.register_buffer('noise_level', noise_level)
|
||||||
|
|
||||||
def compute_noisy_x(self, x):
|
def compute_noisy_x(self, x):
|
||||||
B = x.shape[0]
|
B = x.shape[0]
|
||||||
|
@ -94,7 +88,7 @@ class Wavegrad(nn.Module):
|
||||||
noise_scale = noise_scale.unsqueeze(1)
|
noise_scale = noise_scale.unsqueeze(1)
|
||||||
noise = torch.randn_like(x)
|
noise = torch.randn_like(x)
|
||||||
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
|
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
|
||||||
return noisy_x.unsqueeze(1), noise_scale[:, 0]
|
return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]
|
||||||
|
|
||||||
def forward(self, x, c, noise_scale):
|
def forward(self, x, c, noise_scale):
|
||||||
assert len(c.shape) == 3 # B, C, T
|
assert len(c.shape) == 3 # B, C, T
|
||||||
|
@ -114,9 +108,8 @@ class Wavegrad(nn.Module):
|
||||||
|
|
||||||
def inference(self, c):
|
def inference(self, c):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
|
x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
|
||||||
noise_scale = torch.from_numpy(
|
noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
|
||||||
self.alpha_cum**0.5).float().unsqueeze(1).to(c)
|
|
||||||
for n in range(len(self.alpha) - 1, -1, -1):
|
for n in range(len(self.alpha) - 1, -1, -1):
|
||||||
c1 = 1 / self.alpha[n]**0.5
|
c1 = 1 / self.alpha[n]**0.5
|
||||||
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
|
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
|
||||||
|
|
Loading…
Reference in New Issue