new way of handling memory queue and enable/disable queuing in right/wrong conditions

This commit is contained in:
Eren Golge 2019-07-22 02:10:21 +02:00
parent f4eaec1264
commit f038b1aa3f
1 changed files with 27 additions and 27 deletions

View File

@ -279,15 +279,17 @@ class Decoder(nn.Module):
trans_agent, forward_attn_mask, location_attn,
separate_stopnet):
super(Decoder, self).__init__()
self.r_init = r
self.r = r
self.in_features = in_features
self.max_decoder_steps = 500
self.use_memory_queue = memory_size > 0
self.memory_size = memory_size if memory_size > 0 else r
self.memory_dim = memory_dim
self.separate_stopnet = separate_stopnet
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(
memory_dim * self.memory_size,
memory_dim * self.memory_size if self.use_memory_queue else memory_dim,
prenet_type,
prenet_dropout,
out_features=[256, 128])
@ -311,21 +313,9 @@ class Decoder(nn.Module):
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
# learn init values instead of zero init.
self.attention_rnn_init = nn.Embedding(1, 256)
self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
self.decoder_rnn_inits = nn.Embedding(2, 256)
self.stopnet = StopNet(256 + memory_dim * r)
# self.init_layers()
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.project_to_decoder_in.weight,
gain=torch.nn.init.calculate_gain('linear'))
torch.nn.init.xavier_uniform_(
self.proj_to_mel.weight,
gain=torch.nn.init.calculate_gain('linear'))
self.stopnet = StopNet(256 + memory_dim * self.r_init)
def _set_r(self, new_r):
self.r = new_r
@ -350,13 +340,14 @@ class Decoder(nn.Module):
B = inputs.size(0)
T = inputs.size(1)
# go frame as zeros matrix
self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
if self.use_memory_queue:
self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
else:
self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
# decoder states
self.attention_rnn_hidden = self.attention_rnn_init(
inputs.data.new_zeros(B).long())
self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
self.decoder_rnn_hiddens = [
self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
torch.zeros(B, 256, device=inputs.device)
for idx in range(len(self.decoder_rnns))
]
self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
@ -407,11 +398,20 @@ class Decoder(nn.Module):
output = output[:, : self.r * self.memory_dim]
return output, stop_token, self.attention_layer.attention_weights
def _update_memory_queue(self, new_memory):
self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
],
dim=-1)
def _update_memory_input(self, new_memory):
if self.use_memory_queue:
if self.memory_size > self.r:
# memory queue size is larger than number of frames per decoder iter
self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
],
dim=-1)
else:
# memory queue size smaller than number of frames per decoder iter
self.memory_input = new_memory[:, (self.r - self.memory_size)*self.memory_dim:]
else:
# use only the last frame prediction
self.memory_input = new_memory[:, (self.r-1) * self.memory_dim:]
def forward(self, inputs, memory, mask):
"""
@ -437,7 +437,7 @@ class Decoder(nn.Module):
while len(outputs) < memory.size(0):
if t > 0:
new_memory = memory[t - 1]
self._update_memory_queue(new_memory)
self._update_memory_input(new_memory)
output, stop_token, attention = self.decode(inputs, mask)
outputs += [output]
attentions += [attention]
@ -464,7 +464,7 @@ class Decoder(nn.Module):
while True:
if t > 0:
new_memory = outputs[-1]
self._update_memory_queue(new_memory)
self._update_memory_input(new_memory)
output, stop_token, attention = self.decode(inputs, None)
stop_token = torch.sigmoid(stop_token.data)
outputs += [output]