mirror of https://github.com/coqui-ai/TTS.git
Adding harmonized teacher-forcing
This commit is contained in:
parent
db1b710263
commit
55349f7085
|
@ -286,7 +286,15 @@ class Decoder(nn.Module):
|
||||||
memory_input = initial_memory
|
memory_input = initial_memory
|
||||||
while True:
|
while True:
|
||||||
if t > 0:
|
if t > 0:
|
||||||
memory_input = outputs[-1] if greedy else memory[t - 1]
|
if greedy:
|
||||||
|
memory_input = outputs[-1]
|
||||||
|
else:
|
||||||
|
# combine prev. model output and prev. real target
|
||||||
|
memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
|
||||||
|
# add a random noise
|
||||||
|
noise = torch.autograd.Variable(
|
||||||
|
memory_input.data.new(memory_input.size()).normal_(0.0, 1.0))
|
||||||
|
memory_input = memory_input + noise
|
||||||
|
|
||||||
# Prenet
|
# Prenet
|
||||||
processed_memory = self.prenet(memory_input)
|
processed_memory = self.prenet(memory_input)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class Tacotron(nn.Module):
|
||||||
|
|
||||||
# (B, T', mel_dim*r)
|
# (B, T', mel_dim*r)
|
||||||
mel_outputs, alignments = self.decoder(
|
mel_outputs, alignments = self.decoder(
|
||||||
encoder_outputs, mel_specs, input_lengths=input_lengths)
|
encoder_outputs, mel_specs)
|
||||||
|
|
||||||
# Post net processing below
|
# Post net processing below
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue