From 53ec066733237dbc215b2aeb3a1a749ba2e9170f Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 12 Oct 2019 18:34:12 +0200
Subject: [PATCH] replace zeros() with a better alternative

---
 layers/tacotron.py  |  8 ++++----
 layers/tacotron2.py | 27 +++++++++++----------------
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 04781031..657eefe7 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -340,13 +340,13 @@ class Decoder(nn.Module):
         T = inputs.size(1)
         # go frame as zeros matrix
         if self.use_memory_queue:
-            self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
+            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim * self.memory_size)
         else:
-            self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
+            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim)
         # decoder states
-        self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
+        self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
         self.decoder_rnn_hiddens = [
-            torch.zeros(B, 256, device=inputs.device)
+            torch.zeros(1, device=inputs.device).repeat(B, 256)
             for idx in range(len(self.decoder_rnns))
         ]
         self.context_vec = inputs.data.new(B, self.in_features).zero_()
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index a02ff95a..ea55cbed 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -154,28 +154,23 @@ class Decoder(nn.Module):
 
     def get_go_frame(self, inputs):
         B = inputs.size(0)
-        memory = torch.zeros(B,
-                             self.mel_channels * self.r,
-                             device=inputs.device)
+        memory = torch.zeros(1, device=inputs.device).repeat(B,
+                                                             self.mel_channels * self.r)
         return memory
 
     def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
         # T = inputs.size(1)
         if not keep_states:
-            self.query = torch.zeros(B, self.query_dim, device=inputs.device)
-            self.attention_rnn_cell_state = torch.zeros(B,
-                                                        self.query_dim,
-                                                        device=inputs.device)
-            self.decoder_hidden = torch.zeros(B,
-                                              self.decoder_rnn_dim,
-                                              device=inputs.device)
-            self.decoder_cell = torch.zeros(B,
-                                            self.decoder_rnn_dim,
-                                            device=inputs.device)
-            self.context = torch.zeros(B,
-                                       self.encoder_embedding_dim,
-                                       device=inputs.device)
+            self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
+            self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
+                                                                                        self.query_dim)
+            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
+                                                                              self.decoder_rnn_dim)
+            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
+                                                                            self.decoder_rnn_dim)
+            self.context = torch.zeros(1, device=inputs.device).repeat(B,
+                                                                       self.encoder_embedding_dim)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
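
As a quick sanity check outside the patch itself (a minimal sketch, assuming only PyTorch; the sizes B and D below are illustrative and not taken from either model), the repeat()-based form introduced above should yield a tensor with the same shape, dtype, and values as the original torch.zeros(B, dim, device=...) call:

# Illustrative equivalence check for the pattern used in this patch.
import torch

B, D = 4, 256                      # hypothetical batch size and feature dim
device = torch.device("cpu")       # any device the decoder states live on

old_style = torch.zeros(B, D, device=device)
new_style = torch.zeros(1, device=device).repeat(B, D)  # pattern from the patch

assert old_style.shape == new_style.shape == (B, D)
assert old_style.dtype == new_style.dtype
assert torch.equal(old_style, new_style)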