Setting up network size according to the reference paper

Eren G 2018-08-08 12:34:44 +02:00
parent d53112f52e
commit 2506cd136a
4 changed files with 45 additions and 20 deletions

View File

@@ -109,18 +109,20 @@ class CBHG(nn.Module):
     def __init__(self,
                  in_features,
+                 hid_features=128,
                  K=16,
                  projections=[128, 128],
                  num_highways=4):
         super(CBHG, self).__init__()
         self.in_features = in_features
+        self.hid_features = hid_features
         self.relu = nn.ReLU()
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
         self.conv1d_banks = nn.ModuleList([
             BatchNormConv1d(
                 in_features,
-                in_features,
+                hid_features,
                 kernel_size=k,
                 stride=1,
                 padding=k // 2,
@@ -129,7 +131,7 @@ class CBHG(nn.Module):
         # max pooling of conv bank
         # TODO: try average pooling OR larger kernel size
         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
-        out_features = [K * in_features] + projections[:-1]
+        out_features = [K * hid_features] + projections[:-1]
         activations = [self.relu] * (len(projections) - 1)
         activations += [None]
         # setup conv1d projection layers
@@ -146,12 +148,13 @@ class CBHG(nn.Module):
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
-        self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
+        if self.hid_features != self.in_features:
+            self.pre_highway = nn.Linear(projections[-1], hid_features, bias=False)
         self.highways = nn.ModuleList(
-            [Highway(in_features, in_features) for _ in range(num_highways)])
+            [Highway(hid_features, hid_features) for _ in range(num_highways)])
         # bi-directional GPU layer
         self.gru = nn.GRU(
-            in_features, in_features, 1, batch_first=True, bidirectional=True)
+            128, 128, 1, batch_first=True, bidirectional=True)

     def forward(self, inputs):
         # (B, T_in, in_features)
@@ -161,7 +164,7 @@ class CBHG(nn.Module):
         if x.size(-1) == self.in_features:
             x = x.transpose(1, 2)
         T = x.size(-1)
-        # (B, in_features*K, T_in)
+        # (B, hid_features*K, T_in)
         # Concat conv1d bank outputs
         outs = []
         for conv1d in self.conv1d_banks:
@@ -169,35 +172,45 @@ class CBHG(nn.Module):
             out = out[:, :, :T]
             outs.append(out)
         x = torch.cat(outs, dim=1)
-        assert x.size(1) == self.in_features * len(self.conv1d_banks)
+        assert x.size(1) == self.hid_features * len(self.conv1d_banks)
         x = self.max_pool1d(x)[:, :, :T]

         for conv1d in self.conv1d_projections:
             x = conv1d(x)

-        # (B, T_in, in_features)
-        # Back to the original shape
+        # (B, T_in, hid_feature)
         x = x.transpose(1, 2)
+        # Back to the original shape
+        x += inputs

-        if x.size(-1) != self.in_features:
+        if x.size(-1) != self.hid_features:
             x = self.pre_highway(x)

         # Residual connection
         # TODO: try residual scaling as in Deep Voice 3
         # TODO: try plain residual layers
-        x += inputs
         for highway in self.highways:
             x = highway(x)

-        # (B, T_in, in_features*2)
+        # (B, T_in, hid_features*2)
         # TODO: replace GRU with convolution as in Deep Voice 3
         # self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
         return outputs


+class EncoderCBHG(nn.Module):
+    def __init__(self):
+        super(EncoderCBHG, self).__init__()
+        self.cbhg = CBHG(128, hid_features=128, K=16, projections=[128, 128])
+
+    def forward(self, x):
+        return self.cbhg(x)
+
+
 class Encoder(nn.Module):
     r"""Encapsulate Prenet and CBHG modules for encoder"""

     def __init__(self, in_features):
         super(Encoder, self).__init__()
         self.prenet = Prenet(in_features, out_features=[256, 128])
-        self.cbhg = CBHG(128, K=16, projections=[128, 128])
+        self.cbhg = EncoderCBHG()

     def forward(self, inputs):
         r"""
@@ -212,6 +225,16 @@ class Encoder(nn.Module):
         return self.cbhg(inputs)


+class PostCBHG(nn.Module):
+    def __init__(self, mel_dim):
+        super(PostCBHG, self).__init__()
+        self.cbhg = CBHG(mel_dim, hid_features=128, K=8, projections=[256, mel_dim])
+
+    def forward(self, x):
+        return self.cbhg(x)
+
+
 class Decoder(nn.Module):
     r"""Decoder module.
@@ -336,10 +359,10 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1] / 2 and stop_token > 0.6:
+                if t > inputs.shape[1] / 4 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
-                    print(" | | > Decoder stopped with 'max_decoder_steps")
+                    print(" | > Decoder stopped with 'max_decoder_steps")
                     break
         assert greedy or len(outputs) == T_decoder

         # Back to batch first
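
Note on the CBHG changes above: the module's input width (`in_features`) is now decoupled from its internal width (`hid_features`), which sizes the conv banks, highway layers, and GRU. The residual is also moved to right after the transpose, while `x` still has `projections[-1]` (= `in_features`) channels, with `pre_highway` then projecting to the highway width whenever the two sizes differ. A quick shape check, as a sketch; the import path is taken from the `from layers.tacotron import ...` line in the next file's diff, and the sizes follow the constructors above:

    import torch
    from layers.tacotron import EncoderCBHG, PostCBHG

    # Encoder CBHG: 128-dim input, 128-wide highways,
    # bidirectional GRU -> 256-dim output
    enc = EncoderCBHG()
    x = torch.rand(2, 50, 128)           # (B, T, prenet_out=128)
    assert enc(x).shape == (2, 50, 256)

    # Post-net CBHG: mel_dim input, but still a 256-dim output,
    # independent of mel_dim (80 here is an illustrative value)
    post = PostCBHG(80)
    m = torch.rand(2, 100, 80)           # (B, T, mel_dim)
    assert post(m).shape == (2, 100, 256)

Two details worth flagging: the GRU is now hard-coded as `nn.GRU(128, 128, ...)`, so it only tracks `hid_features` while that keeps its default of 128, and the comment `# bi-directional GPU layer` presumably means "GRU".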

View File

@@ -2,7 +2,7 @@
 import torch
 from torch import nn
 from utils.text.symbols import symbols
-from layers.tacotron import Prenet, Encoder, Decoder, CBHG
+from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG


 class Tacotron(nn.Module):
@@ -22,8 +22,8 @@ class Tacotron(nn.Module):
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
         self.decoder = Decoder(256, mel_dim, r)
-        self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
-        self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
+        self.postnet = PostCBHG(mel_dim)
+        self.last_linear = nn.Linear(256, linear_dim)

     def forward(self, characters, mel_specs=None, text_lens=None):
         B = characters.size(0)
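
Note: the post-net's output width is now fixed by its bidirectional GRU (2 x 128 = 256) rather than by `mel_dim`, which is why `last_linear` changes from `mel_dim * 2` inputs to 256. A minimal, self-contained sketch of the arithmetic (the sizes are illustrative):

    import torch
    from torch import nn

    linear_dim = 1025                    # e.g. 2048-point FFT -> 1025 frequency bins
    gru = nn.GRU(128, 128, 1, batch_first=True, bidirectional=True)
    x = torch.rand(2, 100, 128)          # highway output inside the post-net CBHG
    out, _ = gru(x)
    assert out.size(-1) == 256           # two directions x 128 units each
    last_linear = nn.Linear(256, linear_dim)
    assert last_linear(out).shape == (2, 100, linear_dim)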

View File

@@ -37,6 +37,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_step_time = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    batch_n_iter = len(data_loader.dataset) / c.batch_size
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -114,9 +115,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         epoch_time += step_time

         if current_step % c.print_step == 0:
-            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
+            print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                   "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                   "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
+                                                             batch_n_iter,
                                                              current_step,
                                                              loss.item(),
                                                              linear_loss.item(),
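
Note: `len(data_loader.dataset) / c.batch_size` is true division, so `batch_n_iter` is a float in Python 3 and the log reads like "Step:12/410.25". If an integer count is preferred, a sketch of two equivalent options using the same variables:

    import math

    batch_n_iter = math.ceil(len(data_loader.dataset) / c.batch_size)
    # or, since a PyTorch DataLoader over a sized dataset knows its own length:
    batch_n_iter = len(data_loader)      # accounts for drop_last automatically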

View File

@@ -120,9 +120,9 @@ class AudioProcessor(object):
         D = processor.run_lws(S.astype(np.float64).T**self.power)
         y = processor.istft(D).astype(np.float32)
         # Reconstruct phase
+        sys.stdout = old_out
         if self.preemphasis:
             return self.apply_inv_preemphasis(y)
-        sys.stdout = old_out
         return y

     def _linear_to_mel(self, spectrogram):
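
Note: moving `sys.stdout = old_out` above the `if self.preemphasis:` branch fixes a leak where the early `return self.apply_inv_preemphasis(y)` skipped the restore and left stdout redirected for the rest of the process. A `try`/`finally` sketch of the same idea; `run_quietly` is an illustrative helper, not part of the module's actual API:

    import io
    import sys

    def run_quietly(fn, *args, **kwargs):
        # Call fn with stdout silenced, restoring it on every exit
        # path, including early returns and exceptions.
        old_out = sys.stdout
        sys.stdout = io.StringIO()
        try:
            return fn(*args, **kwargs)
        finally:
            sys.stdout = old_out

Wrapping the chatty lws calls this way, e.g. `y = run_quietly(processor.istft, D)`, would make the restore unconditional instead of depending on statement order.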