mirror of https://github.com/coqui-ai/TTS.git
teacher forcing with combining
parent 91ce166b8a
commit 2dfd22e1d5
@@ -236,7 +236,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)

-    def forward(self, inputs, memory=None, memory_lengths=None):
+    def forward(self, inputs, memory=None, input_lengths=None):
         r"""
         Decoder forward step.

@@ -247,7 +247,7 @@ class Decoder(nn.Module):
             inputs: Encoder outputs.
             memory: Decoder memory (autoregression. If None (at eval-time),
               decoder outputs are used as decoder inputs.
-            memory_lengths: Encoder output (memory) lengths. If not None, used for
+            input_lengths: Encoder output (memory) lengths. If not None, used for
               attention masking.

         Shapes:
@@ -258,8 +258,8 @@ class Decoder(nn.Module):

         # TODO: take this segment into Attention module.
         processed_inputs = self.input_layer(inputs)
-        if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
+        if input_lengths is not None:
+            mask = get_mask_from_lengths(processed_inputs, input_lengths)
         else:
             mask = None

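Background note: the renamed input_lengths argument feeds get_mask_from_lengths, which builds a mask so attention ignores padded encoder steps. The repository's helper is not shown in this diff; the block below is only a minimal sketch of what such length-based masking typically does, with a hypothetical name length_mask and illustrative shapes, not the project's implementation.

import torch

def length_mask(encoder_outputs, lengths):
    # encoder_outputs: (B, T, D); lengths: (B,) number of valid steps per item.
    # Returns a (B, T) boolean mask that is True on padded positions.
    max_len = encoder_outputs.size(1)
    steps = torch.arange(max_len, device=encoder_outputs.device).unsqueeze(0)  # (1, T)
    return steps >= lengths.unsqueeze(1)                                       # (B, T)

# Usage: suppress attention over padding before the softmax.
enc = torch.randn(2, 5, 256)
lengths = torch.tensor([5, 3])
scores = torch.randn(2, 5)                       # one attention score per encoder step
scores = scores.masked_fill(length_mask(enc, lengths), float("-inf"))
weights = torch.softmax(scores, dim=-1)          # padded steps get zero weight

Padded positions are filled with -inf before the softmax so they receive zero attention weight, which is the role of the mask computed in the hunk above.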
@@ -300,7 +300,17 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                memory_input = outputs[-1] if greedy else memory[t - 1]
+                # using harmonized teacher-forcing.
+                # from https://arxiv.org/abs/1707.06588
+                if greedy:
+                    memory_input = outputs[-1]
+                else:
+                    # combine prev. model output and prev. real target
+                    memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                    # add a random noise
+                    memory_input += torch.autograd.Variable(
+                        torch.randn(memory_input.size())).type_as(memory_input)
             # Prenet
             processed_memory = self.prenet(memory_input)

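The hunk above replaces pure teacher forcing: during training, the decoder input for step t becomes the average of the previous model output and the previous ground-truth frame, perturbed with unit-variance Gaussian noise (the idea referenced from https://arxiv.org/abs/1707.06588), while greedy decoding still feeds back the model's own output. Below is a minimal, self-contained sketch of that update in current PyTorch; next_decoder_input, the noise_std argument, and the tensor shapes are illustrative assumptions rather than the repository's API, and torch.randn_like stands in for the deprecated torch.autograd.Variable wrapper.

import torch

def next_decoder_input(prev_output, prev_target=None, noise_std=1.0):
    # prev_output: (B, memory_dim * r) frame block predicted at the previous step.
    # prev_target: ground-truth frame block of the same shape, or None at inference.
    if prev_target is None:
        # greedy / inference path: feed the model's own output back in
        return prev_output
    # training path: average prediction and target, then add Gaussian noise
    combined = (prev_output + prev_target) / 2.0
    return combined + noise_std * torch.randn_like(combined)

# Illustrative step of an unrolled decoder loop (shapes are made up):
outputs = [torch.zeros(4, 80 * 5)]     # last predicted frame block, B=4
memory = torch.randn(7, 4, 80 * 5)     # (T_decoder, B, memory_dim * r) targets
t = 1
memory_input = next_decoder_input(outputs[-1], memory[t - 1])

The default noise_std=1.0 mirrors the diff's unscaled torch.randn call; it is exposed as an argument only to make the unit variance explicit.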
@@ -34,12 +34,13 @@ class Tacotron(nn.Module):
         encoder_outputs = self.encoder(inputs)

         if self.use_memory_mask:
-            memory_lengths = input_lengths
+            input_lengths = input_lengths
         else:
-            memory_lengths = None
+            input_lengths = None

         # (B, T', mel_dim*r)
         mel_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specs, memory_lengths=memory_lengths)
+            encoder_outputs, mel_specs, input_lengths=input_lengths)

         # Post net processing below