mirror of https://github.com/coqui-ai/TTS.git
replace zeros() with a better alternative
This commit is contained in:
parent
fbfa20e3b3
commit
53ec066733
|
@ -340,13 +340,13 @@ class Decoder(nn.Module):
|
||||||
T = inputs.size(1)
|
T = inputs.size(1)
|
||||||
# go frame as zeros matrix
|
# go frame as zeros matrix
|
||||||
if self.use_memory_queue:
|
if self.use_memory_queue:
|
||||||
self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
|
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim * self.memory_size)
|
||||||
else:
|
else:
|
||||||
self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
|
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim)
|
||||||
# decoder states
|
# decoder states
|
||||||
self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
|
self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
|
||||||
self.decoder_rnn_hiddens = [
|
self.decoder_rnn_hiddens = [
|
||||||
torch.zeros(B, 256, device=inputs.device)
|
torch.zeros(1, device=inputs.device).repeat(B, 256)
|
||||||
for idx in range(len(self.decoder_rnns))
|
for idx in range(len(self.decoder_rnns))
|
||||||
]
|
]
|
||||||
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||||
|
|
|
@ -154,28 +154,23 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
def get_go_frame(self, inputs):
|
def get_go_frame(self, inputs):
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
memory = torch.zeros(B,
|
memory = torch.zeros(1, device=inputs.device).repeat(B,
|
||||||
self.mel_channels * self.r,
|
self.mel_channels * self.r)
|
||||||
device=inputs.device)
|
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def _init_states(self, inputs, mask, keep_states=False):
|
def _init_states(self, inputs, mask, keep_states=False):
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
# T = inputs.size(1)
|
# T = inputs.size(1)
|
||||||
if not keep_states:
|
if not keep_states:
|
||||||
self.query = torch.zeros(B, self.query_dim, device=inputs.device)
|
self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
|
||||||
self.attention_rnn_cell_state = torch.zeros(B,
|
self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
|
||||||
self.query_dim,
|
self.query_dim)
|
||||||
device=inputs.device)
|
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
|
||||||
self.decoder_hidden = torch.zeros(B,
|
self.decoder_rnn_dim)
|
||||||
self.decoder_rnn_dim,
|
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
|
||||||
device=inputs.device)
|
self.decoder_rnn_dim)
|
||||||
self.decoder_cell = torch.zeros(B,
|
self.context = torch.zeros(1, device=inputs.device).repeat(B,
|
||||||
self.decoder_rnn_dim,
|
self.encoder_embedding_dim)
|
||||||
device=inputs.device)
|
|
||||||
self.context = torch.zeros(B,
|
|
||||||
self.encoder_embedding_dim,
|
|
||||||
device=inputs.device)
|
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.processed_inputs = self.attention.inputs_layer(inputs)
|
self.processed_inputs = self.attention.inputs_layer(inputs)
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
|
|
Loading…
Reference in New Issue