From 527567d7ceb55487c312dcb69d957eeb512ef31f Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 12 Mar 2019 00:20:15 +0100
Subject: [PATCH] renaming

---
 layers/tacotron2.py | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index a311016e..1eccd466 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -299,15 +299,15 @@ class Decoder(nn.Module):
         memories = memories.transpose(0, 1)
         return memories
 
-    def _parse_outputs(self, outputs, gate_outputs, alignments):
+    def _parse_outputs(self, outputs, stop_tokens, alignments):
         alignments = torch.stack(alignments).transpose(0, 1)
-        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
-        gate_outputs = gate_outputs.contiguous()
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
+        stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         outputs = outputs.view(
             outputs.size(0), -1, self.mel_channels)
         outputs = outputs.transpose(1, 2)
-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments
 
     def decode(self, memory):
         cell_input = torch.cat((memory, self.context), -1)
@@ -354,36 +354,36 @@ class Decoder(nn.Module):
 
         self._init_states(inputs, mask=mask)
 
-        outputs, gate_outputs, alignments = [], [], []
+        outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, gate_output, attention_weights = self.decode(
+            mel_output, stop_token, attention_weights = self.decode(
                 memory)
             outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output.squeeze(1)]
+            stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]
 
-        outputs, gate_outputs, alignments = self._parse_outputs(
-            outputs, gate_outputs, alignments)
+        outputs, stop_tokens, alignments = self._parse_outputs(
+            outputs, stop_tokens, alignments)
 
-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments
 
     def inference(self, inputs):
         memory = self.get_go_frame(inputs)
         self._init_states(inputs, mask=None)
 
         self.attention_layer.init_win_idx()
-        outputs, gate_outputs, alignments, t = [], [], [], 0
+        outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False]
         while True:
             memory = self.prenet(memory)
-            mel_output, gate_output, alignment = self.decode(memory)
-            gate_output = torch.sigmoid(gate_output.data)
+            mel_output, stop_token, alignment = self.decode(memory)
+            stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output]
+            stop_tokens += [stop_token]
             alignments += [alignment]
 
-            stop_flags[0] = stop_flags[0] or gate_output > 0.5
+            stop_flags[0] = stop_flags[0] or stop_token > 0.5
             stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
             if all(stop_flags):
                 break
@@ -394,10 +394,10 @@ class Decoder(nn.Module):
             memory = mel_output
             t += 1
 
-        outputs, gate_outputs, alignments = self._parse_outputs(
-            outputs, gate_outputs, alignments)
+        outputs, stop_tokens, alignments = self._parse_outputs(
+            outputs, stop_tokens, alignments)
 
-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments
 
     def inference_step(self, inputs, t, memory=None):
         """
@@ -408,7 +408,7 @@ class Decoder(nn.Module):
             self._init_states(inputs, mask=None)
 
         memory = self.prenet(memory)
-        mel_output, gate_output, alignment = self.decode(memory)
-        gate_output = torch.sigmoid(gate_output.data)
+        mel_output, stop_token, alignment = self.decode(memory)
+        stop_token = torch.sigmoid(stop_token.data)
         memory = mel_output
-        return mel_output, gate_output, alignment
+        return mel_output, stop_token, alignment