teacher forcing with combining

Eren Golge 2018-02-23 08:35:53 -08:00
parent c72b8fd64c
commit b5f2181e04
2 changed files with 20 additions and 9 deletions

View File

@@ -153,7 +153,7 @@ class CBHG(nn.Module):
             out = conv1d(x)
             out = out[:, :, :T]
             outs.append(out)
         x = torch.cat(outs, dim=1)
         assert x.size(1) == self.in_features * len(self.conv1d_banks)
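
For context, the loop above runs the input through a bank of K parallel conv1d
layers and stacks their outputs along the channel axis, which is what the
assert checks: in_features channels per bank, K banks. A minimal shape sketch,
with hypothetical sizes (B, in_features, T, K are illustrative only):

    import torch

    B, in_features, T, K = 2, 128, 50, 8
    outs = [torch.randn(B, in_features, T) for _ in range(K)]  # one output per bank
    x = torch.cat(outs, dim=1)               # stack along the channel axis
    assert x.size(1) == in_features * K      # -> (B, in_features*K, T)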
@@ -236,7 +236,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)

-    def forward(self, inputs, memory=None, memory_lengths=None):
+    def forward(self, inputs, memory=None, input_lengths=None):
         r"""
         Decoder forward step.
@@ -247,7 +247,7 @@
             inputs: Encoder outputs.
             memory: Decoder memory (autoregression). If None (at eval time),
                 decoder outputs are used as decoder inputs.
-            memory_lengths: Encoder output (memory) lengths. If not None, used for
+            input_lengths: Encoder output (memory) lengths. If not None, used for
                 attention masking.

         Shapes:
@@ -258,8 +258,8 @@
         # TODO: take this segment into Attention module.
         processed_inputs = self.input_layer(inputs)
-        if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
+        if input_lengths is not None:
+            mask = get_mask_from_lengths(processed_inputs, input_lengths)
         else:
             mask = None
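
get_mask_from_lengths itself is not part of this diff; a plausible sketch of
such a helper, assuming it returns a boolean pad mask over the encoder time
axis (an assumption about this repo's helper, not its actual code):

    import torch

    def get_mask_from_lengths_sketch(processed_inputs, lengths):
        """Hypothetical stand-in for get_mask_from_lengths: True marks padded
        encoder steps so attention can ignore them. processed_inputs: (B, T, D)."""
        max_len = processed_inputs.size(1)
        steps = torch.arange(max_len, device=processed_inputs.device)  # (T,)
        return steps.unsqueeze(0) >= lengths.unsqueeze(1)              # (B, T)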
@@ -300,7 +300,17 @@
         memory_input = initial_memory
         while True:
             if t > 0:
-                memory_input = outputs[-1] if greedy else memory[t - 1]
+                # combined ("harmonized") teacher forcing,
+                # following https://arxiv.org/abs/1707.06588
+                if greedy:
+                    memory_input = outputs[-1]
+                else:
+                    # combine the prev. model output and the prev. real target
+                    memory_input = torch.div(outputs[-1] + memory[t - 1], 2.0)
+                    # add random noise
+                    memory_input += torch.autograd.Variable(
+                        torch.randn(memory_input.size())).type_as(memory_input)

             # Prenet
             processed_memory = self.prenet(memory_input)
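
The new branch averages the previous prediction with the previous ground-truth
frame and perturbs the result with unit Gaussian noise, in the spirit of the
linked paper (VoiceLoop, Taigman et al. 2017). A standalone sketch of the same
update, assuming current PyTorch (torch.randn_like instead of the legacy
Variable) and hypothetical tensor names:

    import torch

    def next_decoder_input(prev_output, prev_target, greedy):
        """Sketch of the combined teacher-forcing rule in the hunk above.

        prev_output: last predicted frame,    (B, memory_dim * r)
        prev_target: last ground-truth frame, (B, memory_dim * r)
        """
        if greedy:                  # eval time: feed back the model's own output
            return prev_output
        mixed = (prev_output + prev_target) / 2.0  # average prediction and target
        return mixed + torch.randn_like(mixed)     # unscaled unit Gaussian noise

Averaging keeps the decoder input close to the real target while still exposing
the model to its own errors; note the noise is unscaled here, exactly as in the
commit.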

View File

@@ -34,12 +34,13 @@ class Tacotron(nn.Module):
         encoder_outputs = self.encoder(inputs)

         if self.use_memory_mask:
-            memory_lengths = input_lengths
+            input_lengths = input_lengths
         else:
-            memory_lengths = None
+            input_lengths = None

         # (B, T', mel_dim*r)
         mel_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specs, memory_lengths=memory_lengths)
+            encoder_outputs, mel_specs, input_lengths=input_lengths)

         # Post net processing below
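
Tying the two files together: with use_memory_mask disabled, the decoder
receives input_lengths=None and skips building the attention mask (the
"else: mask = None" branch in the first file). A sketch of that control flow
with illustrative tensors only (sizes are hypothetical, not from the diff):

    import torch

    inputs = torch.randint(0, 60, (2, 37))   # padded character IDs, (B, T_in)
    input_lengths = torch.tensor([37, 25])   # true lengths before padding
    mel_specs = torch.randn(2, 120, 80)      # (B, T', mel_dim)

    use_memory_mask = True
    input_lengths = input_lengths if use_memory_mask else None
    # decoder(encoder_outputs, mel_specs, input_lengths=input_lengths)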