From 40f1a3d3a514a24d7b8b9005410b860168401d11 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 15 May 2018 08:22:42 -0700
Subject: [PATCH] RNN stop-token prediction

---
 layers/tacotron.py | 30 ++++++++++++++++++++++--------
 train.py           |  5 ++---
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 994179df..fea5af45 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -232,12 +232,7 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.stopnet = nn.Sequential(nn.Linear(256 + self.r * memory_dim, memory_dim),
-                                     nn.ReLU(),
-                                     nn.Linear(memory_dim, 256 + self.r * memory_dim),
-                                     nn.ReLU(),
-                                     nn.Linear(256 + self.r * memory_dim, 1),
-                                     nn.Sigmoid())
+        self.stopnet = StopNet(r, memory_dim)
 
     def forward(self, inputs, memory=None):
         """
@@ -273,6 +268,7 @@ class Decoder(nn.Module):
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
+        stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -304,9 +300,9 @@ class Decoder(nn.Module):
             decoder_output = decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
-            stop_input = torch.cat((output, decoder_output), -1)
+            stop_input = output
             # predict stop token
-            stop_token = self.stopnet(stop_input)
+            stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
             alignments += [alignment]
             stop_tokens += [stop_token]
@@ -327,3 +323,21 @@ class Decoder(nn.Module):
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         return outputs, alignments, stop_tokens
+
+
+class StopNet(nn.Module):
+
+    def __init__(self, r, memory_dim):
+        super(StopNet, self).__init__()
+        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
+        self.relu = nn.ReLU()
+        self.linear = nn.Linear(r * memory_dim, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inputs, rnn_hidden):
+        rnn_hidden = self.rnn(inputs, rnn_hidden)
+        outputs = self.relu(rnn_hidden)
+        outputs = self.linear(outputs)
+        outputs = self.sigmoid(outputs)
+        return outputs, rnn_hidden
+
\ No newline at end of file
diff --git a/train.py b/train.py
index 8957fd6f..2516ad55 100644
--- a/train.py
+++ b/train.py
@@ -88,7 +88,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
-        current_lr_st = lr_decay(0.01, current_step, c.warmup_steps)
+        current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
 
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
@@ -363,8 +363,7 @@ def main(args):
                   c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
-    optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=0.01,
-                             momentum=0.9, nesterov=True)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
 
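In short, the patch replaces the stateless MLP stopnet (which concatenated the mel projection with the decoder RNN state) with a `GRUCell`-based `StopNet` that sees only the current mel prediction and carries temporal context in its own hidden state across decoder steps; the trainer correspondingly switches the stopnet optimizer from SGD at a fixed 0.01 to Adam at the shared `c.lr`. Below is a minimal, self-contained sketch of how the new `StopNet` is driven step by step. The class body mirrors the patch; the driver loop, the shapes (`B=2`, `r=5`, `memory_dim=80`), and the dummy tensors are illustrative assumptions, not code from the repository.

```python
import torch
import torch.nn as nn


class StopNet(nn.Module):
    """GRU-based stop-token predictor, as added in the patch above."""

    def __init__(self, r, memory_dim):
        super(StopNet, self).__init__()
        # The GRU input and hidden size both match one decoder output
        # step: r mel frames of memory_dim each.
        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(r * memory_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, rnn_hidden):
        # Update the recurrent state, then squash to a stop probability.
        rnn_hidden = self.rnn(inputs, rnn_hidden)
        outputs = self.relu(rnn_hidden)
        outputs = self.linear(outputs)
        outputs = self.sigmoid(outputs)
        return outputs, rnn_hidden


# Hypothetical shapes: batch of 2, r=5 frames per decoder step, 80-band mels.
B, r, memory_dim = 2, 5, 80
stopnet = StopNet(r, memory_dim)

# As in Decoder.forward, the hidden state starts at zero and is threaded
# through every decoder step.
stopnet_rnn_hidden = torch.zeros(B, r * memory_dim)
for step in range(3):  # a few decoder steps
    mel_output = torch.randn(B, r * memory_dim)  # stand-in for proj_to_mel output
    stop_token, stopnet_rnn_hidden = stopnet(mel_output, stopnet_rnn_hidden)
    print(stop_token.shape)  # torch.Size([2, 1]), a probability in (0, 1)
```

One consequence of this design is that the stop decision no longer depends directly on the 256-dim decoder state at the current step; the GRU's hidden state must accumulate that context from the sequence of mel predictions alone.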