RNN stop-token prediction

Eren Golge 2018-05-15 08:22:42 -07:00
parent a31e60e928
commit 40f1a3d3a5
2 changed files with 24 additions and 11 deletions


@@ -232,12 +232,7 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.stopnet = nn.Sequential(nn.Linear(256 + self.r * memory_dim, memory_dim),
-                                     nn.ReLU(),
-                                     nn.Linear(memory_dim, 256 + self.r * memory_dim),
-                                     nn.ReLU(),
-                                     nn.Linear(256 + self.r * memory_dim, 1),
-                                     nn.Sigmoid())
+        self.stopnet = StopNet(r, memory_dim)

     def forward(self, inputs, memory=None):
         """
@@ -273,6 +268,7 @@ class Decoder(nn.Module):
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
+        stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -304,9 +300,9 @@ class Decoder(nn.Module):
             decoder_output = decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
-            stop_input = torch.cat((output, decoder_output), -1)
+            stop_input = output
             # predict stop token
-            stop_token = self.stopnet(stop_input)
+            stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
             alignments += [alignment]
             stop_tokens += [stop_token]
@@ -327,3 +323,21 @@ class Decoder(nn.Module):
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         return outputs, alignments, stop_tokens
+
+
+class StopNet(nn.Module):
+    def __init__(self, r, memory_dim):
+        super(StopNet, self).__init__()
+        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
+        self.relu = nn.ReLU()
+        self.linear = nn.Linear(r * memory_dim, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inputs, rnn_hidden):
+        rnn_hidden = self.rnn(inputs, rnn_hidden)
+        outputs = self.relu(rnn_hidden)
+        outputs = self.linear(outputs)
+        outputs = self.sigmoid(outputs)
+        return outputs, rnn_hidden
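
The new StopNet takes only the projected mel frames (r * memory_dim features) and carries its own GRU hidden state across decoder steps, replacing the old feed-forward head over the concatenated decoder output. A minimal standalone sketch of that calling pattern follows; r=5, memory_dim=80, the batch size, and the random inputs are illustrative assumptions, not values from this diff.

# Sketch only: shows how the decoder loop drives StopNet with a persistent hidden state.
import torch
import torch.nn as nn

class StopNet(nn.Module):
    def __init__(self, r, memory_dim):
        super(StopNet, self).__init__()
        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(r * memory_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, rnn_hidden):
        rnn_hidden = self.rnn(inputs, rnn_hidden)
        outputs = self.relu(rnn_hidden)
        outputs = self.linear(outputs)
        outputs = self.sigmoid(outputs)
        return outputs, rnn_hidden

r, memory_dim, B = 5, 80, 2                      # illustrative sizes, not from this diff
stopnet = StopNet(r, memory_dim)

# Dedicated hidden state for the stop-token GRU, zero-initialized like
# stopnet_rnn_hidden in the decoder above.
stopnet_rnn_hidden = torch.zeros(B, r * memory_dim)

stop_tokens = []
for _ in range(3):                               # stand-in for the decoder loop
    output = torch.randn(B, r * memory_dim)      # stand-in for the proj_to_mel output
    stop_token, stopnet_rnn_hidden = stopnet(output, stopnet_rnn_hidden)
    stop_tokens.append(stop_token)               # each is a (B, 1) sigmoid probability

print(torch.stack(stop_tokens).shape)            # (3, B, 1)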


@@ -88,7 +88,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     # setup lr
     current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
-    current_lr_st = lr_decay(0.01, current_step, c.warmup_steps)
+    current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)

     for params_group in optimizer.param_groups:
         params_group['lr'] = current_lr
@@ -363,8 +363,7 @@ def main(args):
                     c.r)
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
-    optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=0.01,
-                             momentum=0.9, nesterov=True)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)

     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
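
On the training side, the stop-token head keeps its own optimizer but switches from SGD (lr=0.01, momentum=0.9, nesterov) to Adam at the config learning rate, decayed with the same lr_decay schedule as the main optimizer, while criterion_st remains nn.BCELoss on the sigmoid outputs. A minimal sketch of that separate optimizer/loss wiring follows; the layer sizes, learning rate, and target construction are illustrative assumptions, not from this diff.

# Sketch only: two-optimizer pattern with a dedicated Adam step for the stop-token head.
import torch
import torch.nn as nn
import torch.optim as optim

head = nn.Sequential(nn.Linear(400, 1), nn.Sigmoid())   # stand-in for model.decoder.stopnet
optimizer_st = optim.Adam(head.parameters(), lr=1e-3)    # illustrative lr, stands in for c.lr
criterion_st = nn.BCELoss()

stop_preds = head(torch.randn(8, 400))                   # (B, 1) sigmoid outputs
stop_targets = torch.zeros(8, 1)                          # illustrative: 1.0 only at the final frame group
loss_st = criterion_st(stop_preds, stop_targets)

optimizer_st.zero_grad()
loss_st.backward()
optimizer_st.step()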