This commit is contained in:
Eren Golge 2018-04-30 06:46:37 -07:00
parent e9bf49e1c3
commit 030e942396
5 changed files with 14 additions and 7 deletions

View File

@ -11,7 +11,7 @@
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 50,
"epochs": 500,
"lr": 0.002,
"warmup_steps": 4000,
"batch_size": 32,

View File

@ -233,6 +233,8 @@ class Decoder(nn.Module):
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = nn.Sequential(nn.Dropout(0.3), nn.Linear(80 * self.r, 1), nn.Sigmoid())
def forward(self, inputs, memory=None):
"""
@ -277,6 +279,7 @@ class Decoder(nn.Module):
memory = memory.transpose(0, 1)
outputs = []
alignments = []
stop_tokens = []
t = 0
memory_input = initial_memory
while True:
@ -302,8 +305,11 @@ class Decoder(nn.Module):
output = decoder_input
# predict mel vectors from decoder vectors
output = self.proj_to_mel(output)
# predict stop token
stop_token = self.stopnet(output)
outputs += [output]
alignments += [alignment]
stop_tokens += [stop_token]
t += 1
if (not greedy and self.training) or (greedy and memory is not None):
if t >= T_decoder:
@ -319,7 +325,8 @@ class Decoder(nn.Module):
# Back to batch first
alignments = torch.stack(alignments).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
return outputs, alignments
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
return outputs, alignments, stop_tokens
def is_end_of_frames(output, alignment, eps=0.01): # 0.2

View File

@ -18,7 +18,6 @@ class Tacotron(nn.Module):
self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r)
self.stopnet = nn.Sequential(nn.Linear(80, 1), nn.Sigmoid())
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
@ -28,12 +27,11 @@ class Tacotron(nn.Module):
# batch x time x dim
encoder_outputs = self.encoder(inputs)
# batch x time x dim*r
mel_outputs, alignments = self.decoder(
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs)
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
stop_tokens = self.stopnet(mel_outputs).squeeze()
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens

View File

@ -23,7 +23,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
torch.from_numpy(seq), volatile=True).unsqueeze(0)
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
mel_out, linear_out, alignments = m.forward(chars_var)
mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
spec = ap._denormalize(linear_out)
@ -31,7 +31,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
wav = wav[:ap.find_endpoint(wav)]
out = io.BytesIO()
ap.save_wav(wav, out)
return wav, alignment, spec
return wav, alignment, spec, stop_tokens
def visualize(alignment, spectrogram, CONFIG):

View File

@ -102,6 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
linear_spec = linear_spec.cuda()
stop_target = stop_target.cuda()
stop_target = stop_target.view(B, stop_target.size(1) // c.r, -1)
stop_target = (stop_target.sum(1) > 0.0).long()
# create attention mask
if c.mk > 0.0: