From 030e9423965536379c5f02e775f348db258dd725 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 30 Apr 2018 06:46:37 -0700
Subject: [PATCH] config

---
 config.json        | 2 +-
 layers/tacotron.py | 9 ++++++++-
 models/tacotron.py | 4 +---
 notebooks/utils.py | 4 ++--
 train.py           | 2 ++
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/config.json b/config.json
index 190d8957..671e8561 100644
--- a/config.json
+++ b/config.json
@@ -11,7 +11,7 @@
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
 
-    "epochs": 50,
+    "epochs": 500,
     "lr": 0.002,
     "warmup_steps": 4000,
     "batch_size": 32,
diff --git a/layers/tacotron.py b/layers/tacotron.py
index adf58b34..0305bdec 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -233,6 +233,8 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
+        self.stopnet = nn.Sequential(nn.Dropout(0.3), nn.Linear(80 * self.r, 1), nn.Sigmoid())
+
 
     def forward(self, inputs, memory=None):
         """
@@ -277,6 +279,7 @@ class Decoder(nn.Module):
         memory = memory.transpose(0, 1)
         outputs = []
         alignments = []
+        stop_tokens = []
         t = 0
         memory_input = initial_memory
         while True:
@@ -302,8 +305,11 @@ class Decoder(nn.Module):
             output = decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)
+            # predict stop token
+            stop_token = self.stopnet(output)
             outputs += [output]
             alignments += [alignment]
+            stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
@@ -319,7 +325,8 @@ class Decoder(nn.Module):
         # Back to batch first
         alignments = torch.stack(alignments).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        return outputs, alignments
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
+        return outputs, alignments, stop_tokens
 
 
 def is_end_of_frames(output, alignment, eps=0.01):  # 0.2
diff --git a/models/tacotron.py b/models/tacotron.py
index 9495bf7a..71253149 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -18,7 +18,6 @@ class Tacotron(nn.Module):
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
         self.decoder = Decoder(256, mel_dim, r)
-        self.stopnet = nn.Sequential(nn.Linear(80, 1), nn.Sigmoid())
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
 
@@ -28,12 +27,11 @@ class Tacotron(nn.Module):
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
-        mel_outputs, alignments = self.decoder(
+        mel_outputs, alignments, stop_tokens = self.decoder(
            encoder_outputs, mel_specs)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        stop_tokens = self.stopnet(mel_outputs).squeeze()
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens
diff --git a/notebooks/utils.py b/notebooks/utils.py
index 5d19e204..b37a7241 100644
--- a/notebooks/utils.py
+++ b/notebooks/utils.py
@@ -23,7 +23,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
         torch.from_numpy(seq), volatile=True).unsqueeze(0)
     # mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
 
-    mel_out, linear_out, alignments = m.forward(chars_var)
+    mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
     linear_out = linear_out[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
     spec = ap._denormalize(linear_out)
@@ -31,7 +31,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
     wav = wav[:ap.find_endpoint(wav)]
     out = io.BytesIO()
     ap.save_wav(wav, out)
-    return wav, alignment, spec
+    return wav, alignment, spec, stop_tokens
 
 
 def visualize(alignment, spectrogram, CONFIG):
diff --git a/train.py b/train.py
index 94f7e8ec..a3a93962 100644
--- a/train.py
+++ b/train.py
@@ -102,6 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             linear_spec = linear_spec.cuda()
             stop_target = stop_target.cuda()
 
+        stop_target = stop_target.view(B, stop_target.size(1) // c.r, -1)
+        stop_target = (stop_target.sum(1) > 0.0).long()
 
         # create attention mask
         if c.mk > 0.0:
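
Note on the stop-token wiring above: the decoder now emits one sigmoid stop
probability per decoder step (a group of r mel frames), and train.py collapses
the frame-level stop flags to the same per-step resolution. The sketch below
is a minimal, self-contained illustration of how those targets could be paired
with the decoder output; it is not taken from the patch. It assumes
`criterion_st` (passed into train() in the hunk header above) is nn.BCELoss(),
consistent with the Sigmoid at the end of `stopnet`. Note also that the patch
reduces with sum(1) and casts to long, whereas the sketch reduces over the
last (frame) axis and keeps a float target, which is the shape and dtype
nn.BCELoss expects.

import torch
import torch.nn as nn

B, T, r = 2, 15, 5            # batch size, mel frames, reduction factor
criterion_st = nn.BCELoss()   # assumed definition of the stop-token criterion

# Frame-level stop flags as a data loader might provide them:
# 0 while speech continues, 1 from the last real frame onward.
stop_target = torch.zeros(B, T)
stop_target[0, 10:] = 1.0
stop_target[1, 13:] = 1.0

# Group the T frames into T // r decoder steps and flag a step as "stop"
# if any of its r frames is flagged (reduce over the frame axis).
stop_target = stop_target.view(B, T // r, -1)      # (B, T // r, r)
stop_target = (stop_target.sum(-1) > 0.0).float()  # (B, T // r)

# The decoder stacks one stopnet output per step: shape (B, T // r, 1).
stop_tokens = torch.rand(B, T // r, 1)             # stand-in for model output
stop_loss = criterion_st(stop_tokens.squeeze(-1), stop_target)
print(stop_loss.item())

One detail visible in the diff itself: the new stopnet hard-codes the mel
dimension as 80 (nn.Linear(80 * self.r, 1)), while proj_to_mel derives its
width from memory_dim; the two stay in sync only as long as memory_dim == 80.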