mirror of https://github.com/coqui-ai/TTS.git
teacher forcing with combining
This commit is contained in:
parent
c72b8fd64c
commit
b5f2181e04
|
@ -153,7 +153,7 @@ class CBHG(nn.Module):
|
|||
out = conv1d(x)
|
||||
out = out[:, :, :T]
|
||||
outs.append(out)
|
||||
|
||||
|
||||
x = torch.cat(outs, dim=1)
|
||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
||||
|
||||
|
@ -236,7 +236,7 @@ class Decoder(nn.Module):
|
|||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
|
||||
def forward(self, inputs, memory=None, memory_lengths=None):
|
||||
def forward(self, inputs, memory=None, input_lengths=None):
|
||||
r"""
|
||||
Decoder forward step.
|
||||
|
||||
|
@ -247,7 +247,7 @@ class Decoder(nn.Module):
|
|||
inputs: Encoder outputs.
|
||||
memory: Decoder memory (autoregression. If None (at eval-time),
|
||||
decoder outputs are used as decoder inputs.
|
||||
memory_lengths: Encoder output (memory) lengths. If not None, used for
|
||||
input_lengths: Encoder output (memory) lengths. If not None, used for
|
||||
attention masking.
|
||||
|
||||
Shapes:
|
||||
|
@ -258,8 +258,8 @@ class Decoder(nn.Module):
|
|||
|
||||
# TODO: take this segment into Attention module.
|
||||
processed_inputs = self.input_layer(inputs)
|
||||
if memory_lengths is not None:
|
||||
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
|
||||
if input_lengths is not None:
|
||||
mask = get_mask_from_lengths(processed_inputs, input_lengths)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
|
@ -300,7 +300,17 @@ class Decoder(nn.Module):
|
|||
memory_input = initial_memory
|
||||
while True:
|
||||
if t > 0:
|
||||
memory_input = outputs[-1] if greedy else memory[t - 1]
|
||||
# using harmonized teacher-forcing.
|
||||
# from https://arxiv.org/abs/1707.06588
|
||||
if greedy:
|
||||
memory_input = outputs[-1]
|
||||
else:
|
||||
# combine prev. model output and prev. real target
|
||||
memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
|
||||
# add a random noise
|
||||
memory_input += torch.autograd.Variable(
|
||||
torch.randn(memory_input.size())).type_as(memory_input)
|
||||
|
||||
# Prenet
|
||||
processed_memory = self.prenet(memory_input)
|
||||
|
||||
|
|
|
@ -34,12 +34,13 @@ class Tacotron(nn.Module):
|
|||
encoder_outputs = self.encoder(inputs)
|
||||
|
||||
if self.use_memory_mask:
|
||||
memory_lengths = input_lengths
|
||||
input_lengths = input_lengths
|
||||
else:
|
||||
memory_lengths = None
|
||||
input_lengths = None
|
||||
|
||||
# (B, T', mel_dim*r)
|
||||
mel_outputs, alignments = self.decoder(
|
||||
encoder_outputs, mel_specs, memory_lengths=memory_lengths)
|
||||
encoder_outputs, mel_specs, input_lengths=input_lengths)
|
||||
|
||||
# Post net processing below
|
||||
|
||||
|
|
Loading…
Reference in New Issue