wavegrad dataset update

This commit is contained in:
erogol 2020-10-26 16:47:09 +01:00
parent 5b5b9fcfdd
commit dc2825dfb2
2 changed files with 111 additions and 129 deletions

View File

@ -28,7 +28,6 @@ class WaveGradDataset(Dataset):
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
@ -64,57 +63,49 @@ class WaveGradDataset(Dataset):
def load_test_samples(self, num_samples):
    """Return the first ``num_samples`` dataset items as full-length pairs.

    Args:
        num_samples (int): number of items to load, starting at index 0.

    Returns:
        list: ``[[mel, audio], ...]`` with segment slicing disabled, so each
        entry is the whole utterance rather than a random training segment.
    """
    samples = []
    return_segments = self.return_segments
    # Temporarily disable segment slicing so load_item returns full clips.
    self.return_segments = False
    try:
        for idx in range(num_samples):
            mel, audio = self.load_item(idx)
            samples.append([mel, audio])
    finally:
        # Restore the caller-visible flag even if load_item raises
        # (the original leaked the temporary False value on error).
        self.return_segments = return_segments
    return samples
def load_item(self, idx):
    """Load one ``(mel, audio)`` couple for dataset index ``idx``.

    Depending on ``self.compute_feat``, either computes the mel spectrogram
    from the wav file on the fly or loads a precomputed feature file.
    The audio is length-corrected to match the mel frame count, converted
    to torch tensors, and optionally sliced to a random training segment.

    Returns:
        tuple: ``(mel, audio)`` where ``mel`` is a FloatTensor of shape
        ``(num_mels, T_frames)`` and ``audio`` of shape ``(1, T_samples)``
        with ``T_samples == T_frames * hop_len`` (or ``seq_len`` when
        ``return_segments`` is enabled).
    """
    if self.compute_feat:
        # Compute features from the wav file directly.
        wavpath = self.item_list[idx]
        if self.use_cache and self.cache[idx] is not None:
            audio, mel = self.cache[idx]
        else:
            audio = self.ap.load_wav(wavpath)
            # Pad clips shorter than one training segment.
            if len(audio) < self.seq_len + self.pad_short:
                audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)),
                               mode='constant', constant_values=0.0)
            mel = self.ap.melspectrogram(audio)
    else:
        # Load precomputed features from disk.
        wavpath, feat_path = self.item_list[idx]
        if self.use_cache and self.cache[idx] is not None:
            audio, mel = self.cache[idx]
        else:
            audio = self.ap.load_wav(wavpath)
            mel = np.load(feat_path)

    # Populate the cache so use_cache actually takes effect
    # (the cache was read above but never written).
    if self.use_cache:
        self.cache[idx] = (audio, mel)

    # Correct the audio length wrt hop length.
    p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
    audio = np.pad(audio, (0, p), mode='constant', constant_values=0.0)

    # Correct the audio length wrt the padding applied in stft.
    audio = np.pad(audio, (0, self.hop_len), mode="edge")
    audio = audio[:mel.shape[-1] * self.hop_len]
    assert mel.shape[-1] * self.hop_len == audio.shape[-1], \
        f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'

    audio = torch.from_numpy(audio).float().unsqueeze(0)
    mel = torch.from_numpy(mel).float().squeeze(0)

    if self.return_segments:
        # Pick a random mel window and the matching audio span.
        max_mel_start = mel.shape[1] - self.feat_frame_len
        mel_start = random.randint(0, max_mel_start)
        mel_end = mel_start + self.feat_frame_len
        mel = mel[:, mel_start:mel_end]

        audio_start = mel_start * self.hop_len
        audio = audio[:, audio_start:audio_start + self.seq_len]

    if self.use_noise_augment and self.is_training and self.return_segments:
        # Light dither at the 16-bit quantization level.
        audio = audio + (1 / 32768) * torch.randn_like(audio)

    return (mel, audio)

View File

@ -1,32 +1,25 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torch.nn as nn
import torch.nn.functional as F
from math import log as ln
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` initialized with an orthogonal weight and zero bias.

    Note: overriding ``reset_parameters`` means this init also runs inside
    ``nn.Conv1d.__init__``; the explicit call afterwards is a harmless
    re-initialization.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.weight)
        # Guard: with bias=False, self.bias is None and zeros_ would crash.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
class PositionalEncoding(nn.Module):
    """Sinusoidal noise-level encoding.

    Maps a per-sample scalar noise level to an ``n_channels`` vector of
    sines and cosines at log-spaced frequencies and adds it to every time
    step of the input feature map.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        # Half the channels carry sin terms, half cos terms.
        self.length = n_channels // 2
        assert n_channels % 2 == 0

    def forward(self, x, noise_level):
        """
        Shapes:
            x: ``(B, n_channels, T)``
            noise_level: ``(B,)``
        """
        # Broadcast the (B, n_channels) encoding over the time axis.
        return (x + self.encoding(noise_level)[:, :, None])

    def encoding(self, noise_level):
        """Return the ``(B, n_channels)`` sinusoidal encoding of ``noise_level``."""
        # Frequencies decay geometrically from 1 down to ~1e-4.
        step = torch.arange(
            self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
        encoding = noise_level.unsqueeze(1) * torch.exp(
            -ln(1e4) * step.unsqueeze(0))
        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding
class FiLM(nn.Module):
    """Feature-wise Linear Modulation.

    Combines information from the noisy waveform features and the noise
    level: given an input feature map and a noise scale, it produces the
    ``shift`` and ``scale`` vectors used by a UBlock for feature-wise
    affine transformation.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
        self.ini_parameters()

    def ini_parameters(self):
        # Xavier weights and zero biases for both convolutions.
        nn.init.xavier_uniform_(self.input_conv.weight)
        nn.init.xavier_uniform_(self.output_conv.weight)
        nn.init.zeros_(self.input_conv.bias)
        nn.init.zeros_(self.output_conv.bias)

    def forward(self, x, noise_scale):
        """Return ``(shift, scale)``, each of shape ``(B, output_size, T)``."""
        x = self.input_conv(x)
        x = F.leaky_relu(x, 0.2)
        # Inject the noise-level information before projecting.
        x = self.encoding(x, noise_scale)
        # output_conv emits 2*output_size channels, split into shift/scale.
        shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
        return shift, scale
@ -80,82 +66,87 @@ def shif_and_scale(x, scale, shift):
class UBlock(nn.Module):
    """Upsampling residual block conditioned by FiLM ``shift``/``scale``.

    Upsamples the input by ``factor`` along the time axis and fuses a
    1x1-conv shortcut (block1) with a dilated-conv main path (block2),
    then refines the sum with a second dilated pair (block3). The FiLM
    conditioning is applied via the module-level ``shif_and_scale`` helper.
    """

    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple))
        # Exactly four dilation rates: two for block2, two for block3.
        assert len(dilation) == 4
        self.factor = factor
        self.block1 = Conv1d(input_size, hidden_size, 1)
        self.block2 = nn.ModuleList([
            Conv1d(input_size,
                   hidden_size,
                   3,
                   dilation=dilation[0],
                   padding=dilation[0]),
            Conv1d(hidden_size,
                   hidden_size,
                   3,
                   dilation=dilation[1],
                   padding=dilation[1])
        ])
        self.block3 = nn.ModuleList([
            Conv1d(hidden_size,
                   hidden_size,
                   3,
                   dilation=dilation[2],
                   padding=dilation[2]),
            Conv1d(hidden_size,
                   hidden_size,
                   3,
                   dilation=dilation[3],
                   padding=dilation[3])
        ])

    def forward(self, x, shift, scale):
        """
        Shapes:
            x: ``(B, input_size, T)``; returns ``(B, hidden_size, T * factor)``.
        """
        # Shortcut path: upsample then 1x1 projection.
        block1 = F.interpolate(x, size=x.shape[-1] * self.factor)
        block1 = self.block1(block1)

        # Main path: activation, upsample, conditioned dilated convs.
        block2 = F.leaky_relu(x, 0.2)
        block2 = F.interpolate(block2, size=x.shape[-1] * self.factor)
        block2 = self.block2[0](block2)
        block2 = shif_and_scale(block2, scale, shift)
        block2 = F.leaky_relu(block2, 0.2)
        block2 = self.block2[1](block2)

        x = block1 + block2

        # Second residual refinement at the upsampled resolution.
        block3 = shif_and_scale(x, scale, shift)
        block3 = F.leaky_relu(block3, 0.2)
        block3 = self.block3[0](block3)
        block3 = shif_and_scale(block3, scale, shift)
        block3 = F.leaky_relu(block3, 0.2)
        block3 = self.block3[1](block3)

        x = x + block3
        return x
class DBlock(nn.Module):
    """Downsampling residual block.

    Downsamples the input by ``factor`` along the time axis and combines a
    1x1-conv residual path with a stack of dilated convolutions.
    """

    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.residual_dense = Conv1d(input_size, hidden_size, 1)
        # Dilations 1/2/4 widen the receptive field at constant length.
        self.conv = nn.ModuleList([
            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
        ])

    def forward(self, x):
        """
        Shapes:
            x: ``(B, input_size, T)``; returns ``(B, hidden_size, T // factor)``.
        """
        size = x.shape[-1] // self.factor
        # Residual path: project, then downsample.
        residual = self.residual_dense(x)
        residual = F.interpolate(residual, size=size)
        # Main path: downsample, then dilated conv stack.
        x = F.interpolate(x, size=size)
        for layer in self.conv:
            x = F.leaky_relu(x, 0.2)
            x = layer(x)
        return x + residual