mirror of https://github.com/coqui-ai/TTS.git
Adding harmonized teacher-forcing
This commit is contained in:
parent
db1b710263
commit
55349f7085
@@ -286,7 +286,15 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                memory_input = outputs[-1] if greedy else memory[t - 1]
+                if greedy:
+                    memory_input = outputs[-1]
+                else:
+                    # combine prev. model output and prev. real target
+                    memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
+                    # add a random noise
+                    noise = torch.autograd.Variable(
+                        memory_input.data.new(memory_input.size()).normal_(0.0, 1.0))
+                    memory_input = memory_input + noise

             # Prenet
             processed_memory = self.prenet(memory_input)
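The gist of the hunk above: in teacher-forced (non-greedy) mode the decoder no longer feeds back the pure ground-truth frame. It averages its own previous output with the previous target frame and perturbs the result with unit Gaussian noise, so training inputs look more like the imperfect outputs the model will feed itself at inference time. A minimal, self-contained sketch of that input-selection step, with illustrative names (prev_output, prev_target) and torch.randn_like standing in for the now-deprecated Variable wrapper:

    import torch

    def harmonized_decoder_input(prev_output: torch.Tensor,
                                 prev_target: torch.Tensor,
                                 greedy: bool) -> torch.Tensor:
        # greedy=True: free running, feed back the model's own last output.
        if greedy:
            return prev_output
        # Harmonized teacher forcing: average last output with last target...
        mixed = (prev_output + prev_target) / 2.0
        # ...then add N(0, 1) noise of the same shape and device.
        return mixed + torch.randn_like(mixed)

    # One decoding step with r=5 stacked 80-dim mel frames (shapes illustrative).
    prev_output = torch.rand(32, 400)   # (batch, mel_dim * r)
    prev_target = torch.rand(32, 400)
    memory_input = harmonized_decoder_input(prev_output, prev_target, greedy=False)

Averaging keeps the input anchored to the reference while the noise discourages the prenet from relying on exact target values; the greedy (inference) path is unchanged by the commit.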
@@ -338,4 +346,4 @@ class Decoder(nn.Module):


 def is_end_of_frames(output, eps=0.2): #0.2
-    return (output.data <= eps).all()
+    return (output.data <= eps).all()
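For context, is_end_of_frames is the stopping test for the greedy decoding loop: generation halts once every value of the predicted frame is at or below eps, i.e. the frame is effectively silence. A small sketch of that behavior (tensor shapes are illustrative):

    import torch

    def is_end_of_frames(output, eps=0.2):
        # True when the entire predicted frame sits at or below the threshold.
        return (output.data <= eps).all()

    frame = torch.full((1, 400), 0.1)   # quiet frame: every value <= eps
    assert is_end_of_frames(frame)      # -> stop decoding
    frame[0, 0] = 0.9                   # one energetic value
    assert not is_end_of_frames(frame)  # -> keep decoding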
@@ -34,7 +34,7 @@ class Tacotron(nn.Module):

         # (B, T', mel_dim*r)
         mel_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specs, input_lengths=input_lengths)
+            encoder_outputs, mel_specs)

         # Post net processing below
