fixing size mismatch

This commit is contained in:
Eren 2018-08-10 18:47:09 +02:00
parent 9100e5762a
commit 3b2654203d
3 changed files with 45 additions and 22 deletions

View File

@ -109,20 +109,25 @@ class CBHG(nn.Module):
def __init__(self,
in_features,
hid_features=128,
K=16,
projections=[128, 128],
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4):
super(CBHG, self).__init__()
self.in_features = in_features
self.hid_features = hid_features
self.conv_bank_features = conv_bank_features
self.highway_features = highway_features
self.gru_features = gru_features
self.conv_projections = conv_projections
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList([
BatchNormConv1d(
in_features,
hid_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=k // 2,
@ -131,12 +136,12 @@ class CBHG(nn.Module):
# max pooling of conv bank
# TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
out_features = [K * hid_features] + projections[:-1]
activations = [self.relu] * (len(projections) - 1)
out_features = [K * conv_bank_features] + conv_projections[:-1]
activations = [self.relu] * (len(conv_projections) - 1)
activations += [None]
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections,
for (in_size, out_size, ac) in zip(out_features, conv_projections,
activations):
layer = BatchNormConv1d(
in_size,
@ -148,13 +153,20 @@ class CBHG(nn.Module):
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
if self.hid_features != self.in_features:
self.pre_highway = nn.Linear(projections[-1], hid_features, bias=False)
self.highways = nn.ModuleList(
[Highway(hid_features, hid_features) for _ in range(num_highways)])
if self.highway_features != conv_projections[-1]:
self.pre_highway = nn.Linear(
conv_projections[-1], highway_features, bias=False)
self.highways = nn.ModuleList([
Highway(highway_features, highway_features)
for _ in range(num_highways)
])
# bi-directional GPU layer
self.gru = nn.GRU(
128, 128, 1, batch_first=True, bidirectional=True)
gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
def forward(self, inputs):
# (B, T_in, in_features)
@ -172,7 +184,7 @@ class CBHG(nn.Module):
out = out[:, :, :T]
outs.append(out)
x = torch.cat(outs, dim=1)
assert x.size(1) == self.hid_features * len(self.conv1d_banks)
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
@ -180,7 +192,7 @@ class CBHG(nn.Module):
x = x.transpose(1, 2)
# Back to the original shape
x += inputs
if x.size(-1) != self.hid_features:
if self.highway_features != self.conv_projections[-1]:
x = self.pre_highway(x)
# Residual connection
# TODO: try residual scaling as in Deep Voice 3
@ -195,10 +207,16 @@ class CBHG(nn.Module):
class EncoderCBHG(nn.Module):
def __init__(self):
super(EncoderCBHG, self).__init__()
self.cbhg = CBHG(128, hid_features=128, K=16, projections=[128, 128])
self.cbhg = CBHG(
128,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4)
def forward(self, x):
return self.cbhg(x)
@ -226,11 +244,16 @@ class Encoder(nn.Module):
class PostCBHG(nn.Module):
def __init__(self, mel_dim):
super(PostCBHG, self).__init__()
self.cbhg = CBHG(mel_dim, hid_features=128, K=8, projections=[256, mel_dim])
self.cbhg = CBHG(
mel_dim,
K=8,
conv_bank_features=80,
conv_projections=[160, mel_dim],
highway_features=80,
gru_features=80,
num_highways=4)
def forward(self, x):
return self.cbhg(x)

View File

@ -23,7 +23,7 @@ class Tacotron(nn.Module):
self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r)
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(256, linear_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0)

View File

@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
plt.close()
return data
@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
plt.close()
return data