mirror of https://github.com/coqui-ai/TTS.git

fixing size mismatch

This commit is contained in:
parent d5a909807e
commit 134603e1ec
@@ -109,20 +109,25 @@ class CBHG(nn.Module):
     def __init__(self,
                  in_features,
-                 hid_features=128,
                  K=16,
-                 projections=[128, 128],
+                 conv_bank_features=128,
+                 conv_projections=[128, 128],
+                 highway_features=128,
+                 gru_features=128,
                  num_highways=4):
         super(CBHG, self).__init__()
         self.in_features = in_features
-        self.hid_features = hid_features
+        self.conv_bank_features = conv_bank_features
+        self.highway_features = highway_features
+        self.gru_features = gru_features
+        self.conv_projections = conv_projections
         self.relu = nn.ReLU()
         # list of conv1d banks with filter sizes k=1...K
         # TODO: try dilated convolutions instead
         self.conv1d_banks = nn.ModuleList([
             BatchNormConv1d(
                 in_features,
-                hid_features,
+                conv_bank_features,
                 kernel_size=k,
                 stride=1,
                 padding=k // 2,
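For orientation, the bank's channel arithmetic can be checked in isolation. A minimal sketch with plain `nn.Conv1d` standing in for the repo's `BatchNormConv1d` (the batch-norm wrapper does not change channel counts):

```python
import torch
import torch.nn as nn

B, T = 4, 100
K, in_features, conv_bank_features = 16, 128, 128

# K parallel convolutions with kernel sizes 1..K, each producing
# conv_bank_features channels; padding=k//2 keeps the length at ~T
# (one extra step for even k, trimmed below as in CBHG.forward).
banks = nn.ModuleList([
    nn.Conv1d(in_features, conv_bank_features, kernel_size=k,
              stride=1, padding=k // 2)
    for k in range(1, K + 1)
])

x = torch.randn(B, in_features, T)            # (B, C, T)
outs = [conv(x)[:, :, :T] for conv in banks]  # trim the even-kernel overhang
y = torch.cat(outs, dim=1)                    # concatenate along channels
assert y.size(1) == K * conv_bank_features    # 2048, matching the assert in forward()
```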
@@ -131,12 +136,12 @@ class CBHG(nn.Module):
         # max pooling of conv bank
         # TODO: try average pooling OR larger kernel size
         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
-        out_features = [K * hid_features] + projections[:-1]
-        activations = [self.relu] * (len(projections) - 1)
+        out_features = [K * conv_bank_features] + conv_projections[:-1]
+        activations = [self.relu] * (len(conv_projections) - 1)
         activations += [None]
         # setup conv1d projection layers
         layer_set = []
-        for (in_size, out_size, ac) in zip(out_features, projections,
+        for (in_size, out_size, ac) in zip(out_features, conv_projections,
                                            activations):
             layer = BatchNormConv1d(
                 in_size,
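With the defaults above (K=16, conv_bank_features=128, conv_projections=[128, 128]), the in/out chaining of the projection layers works out as follows; a quick check of the zip:

```python
K, conv_bank_features = 16, 128
conv_projections = [128, 128]

# Despite its name, out_features lists the INPUT width of each projection.
out_features = [K * conv_bank_features] + conv_projections[:-1]
print(out_features)                              # [2048, 128]
print(list(zip(out_features, conv_projections)))
# [(2048, 128), (128, 128)] -> the first layer absorbs the concatenated
# bank output; the last layer gets activation None (no ReLU), as set above.
```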
@@ -148,13 +153,20 @@ class CBHG(nn.Module):
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
-        if self.hid_features != self.in_features:
-            self.pre_highway = nn.Linear(projections[-1], hid_features, bias=False)
-        self.highways = nn.ModuleList(
-            [Highway(hid_features, hid_features) for _ in range(num_highways)])
+        if self.highway_features != conv_projections[-1]:
+            self.pre_highway = nn.Linear(
+                conv_projections[-1], highway_features, bias=False)
+        self.highways = nn.ModuleList([
+            Highway(highway_features, highway_features)
+            for _ in range(num_highways)
+        ])
         # bi-directional GRU layer
         self.gru = nn.GRU(
-            128, 128, 1, batch_first=True, bidirectional=True)
+            gru_features,
+            gru_features,
+            1,
+            batch_first=True,
+            bidirectional=True)

     def forward(self, inputs):
         # (B, T_in, in_features)
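Two size relationships fall out of this hunk. First, pre_highway is now created under the same condition that guards its use in forward() (see the hunk two below), so the bridge from conv_projections[-1] to highway_features exists exactly when it is applied. Second, a bidirectional GRU concatenates forward and backward states, doubling the feature width; a minimal sketch:

```python
import torch
import torch.nn as nn

gru_features = 128
gru = nn.GRU(gru_features, gru_features, 1,
             batch_first=True, bidirectional=True)

x = torch.randn(4, 100, gru_features)            # (B, T, gru_features)
out, _ = gru(x)
assert out.shape == (4, 100, 2 * gru_features)   # forward + backward states
```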
@@ -172,7 +184,7 @@ class CBHG(nn.Module):
             out = out[:, :, :T]
             outs.append(out)
         x = torch.cat(outs, dim=1)
-        assert x.size(1) == self.hid_features * len(self.conv1d_banks)
+        assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
         x = self.max_pool1d(x)[:, :, :T]
         for conv1d in self.conv1d_projections:
             x = conv1d(x)
@@ -180,7 +192,7 @@ class CBHG(nn.Module):
         x = x.transpose(1, 2)
         # Back to the original shape
         x += inputs
-        if x.size(-1) != self.hid_features:
+        if self.highway_features != self.conv_projections[-1]:
             x = self.pre_highway(x)
         # Residual connection
         # TODO: try residual scaling as in Deep Voice 3
@@ -195,10 +207,16 @@ class CBHG(nn.Module):


 class EncoderCBHG(nn.Module):

     def __init__(self):
         super(EncoderCBHG, self).__init__()
-        self.cbhg = CBHG(128, hid_features=128, K=16, projections=[128, 128])
+        self.cbhg = CBHG(
+            128,
+            K=16,
+            conv_bank_features=128,
+            conv_projections=[128, 128],
+            highway_features=128,
+            gru_features=128,
+            num_highways=4)

     def forward(self, x):
         return self.cbhg(x)
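For the encoder, every new size is left at 128, so this call builds the same stack as the old CBHG(128, hid_features=128, K=16, projections=[128, 128]); the refactor only changes behaviour where the sizes diverge, as in PostCBHG below. A shape check, assuming CBHG.forward returns the GRU output as in the rest of this file:

```python
import torch

enc = EncoderCBHG()              # all widths stay at 128
x = torch.randn(4, 100, 128)     # (B, T, in_features)
y = enc(x)
assert y.shape == (4, 100, 256)  # 2 * gru_features from the bidirectional GRU
```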
@@ -226,11 +244,16 @@ class Encoder(nn.Module):


 class PostCBHG(nn.Module):

     def __init__(self, mel_dim):
         super(PostCBHG, self).__init__()
-        self.cbhg = CBHG(mel_dim, hid_features=128, K=8, projections=[256, mel_dim])
-
+        self.cbhg = CBHG(
+            mel_dim,
+            K=8,
+            conv_bank_features=80,
+            conv_projections=[160, mel_dim],
+            highway_features=80,
+            gru_features=80,
+            num_highways=4)
     def forward(self, x):
         return self.cbhg(x)
@@ -23,7 +23,7 @@ class Tacotron(nn.Module):
         self.encoder = Encoder(embedding_dim)
         self.decoder = Decoder(256, mel_dim, r)
         self.postnet = PostCBHG(mel_dim)
-        self.last_linear = nn.Linear(256, linear_dim)
+        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)

     def forward(self, characters, mel_specs=None, mask=None):
         B = characters.size(0)
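This last line is the size-mismatch fix itself: PostCBHG now runs an 80-wide bidirectional GRU, so the postnet emits 2 * 80 = 160 features rather than the 256 that was hard-coded before, and deriving the width from self.postnet.cbhg.gru_features keeps the two modules in sync if the postnet sizes change again. A sketch, where 1025 is a hypothetical stand-in for linear_dim (e.g. n_fft // 2 + 1):

```python
import torch
import torch.nn as nn

mel_dim, linear_dim = 80, 1025   # linear_dim is an assumed example value
postnet = PostCBHG(mel_dim)
last_linear = nn.Linear(postnet.cbhg.gru_features * 2, linear_dim)  # 160 -> 1025

mel_frames = torch.randn(4, 100, mel_dim)    # (B, T, mel_dim) decoder output
linear_out = last_linear(postnet(mel_frames))
assert linear_out.shape == (4, 100, linear_dim)
```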
@@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
     plt.tight_layout()
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
     plt.close()
     return data
@@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
     plt.tight_layout()
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    data = data.reshape((3, ) + fig.canvas.get_width_height()[::-1])
     plt.close()
     return data
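One caveat worth flagging on these two identical hunks: tostring_rgb returns a flat, row-major, interleaved RGB buffer (RGBRGB...). Reshaping it straight to (3, H, W) therefore slices the buffer into three thirds rather than extracting colour planes; the value-preserving way to get channel-first data (e.g. for an image logger that expects CHW) is to reshape to (H, W, 3) and transpose. A small numpy sketch of the difference:

```python
import numpy as np

h, w = 2, 3
buf = np.arange(h * w * 3, dtype=np.uint8)  # stand-in for the flat RGB buffer

hwc = buf.reshape((h, w, 3))                # old layout: pixels intact, channels last
chw_reshaped = buf.reshape((3, h, w))       # new code: chops the buffer into thirds
chw_transposed = hwc.transpose(2, 0, 1)     # channel planes with pixels intact

assert not np.array_equal(chw_reshaped, chw_transposed)
```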