mirror of https://github.com/coqui-ai/TTS.git
New way of handling the memory queue, and of enabling/disabling queuing under the right/wrong conditions
This commit is contained in:
parent f4eaec1264
commit f038b1aa3f
```diff
@@ -279,15 +279,17 @@ class Decoder(nn.Module):
                  trans_agent, forward_attn_mask, location_attn,
                  separate_stopnet):
         super(Decoder, self).__init__()
+        self.r_init = r
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
+        self.use_memory_queue = memory_size > 0
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
-            memory_dim * self.memory_size,
+            memory_dim * self.memory_size if self.use_memory_queue else memory_dim,
             prenet_type,
             prenet_dropout,
             out_features=[256, 128])
```
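The new `use_memory_queue` flag records whether a positive `memory_size` was passed; when it is off, the prenet now consumes a single frame instead of a flattened queue. A minimal sketch of the resulting prenet input width, with hypothetical sizes:

```python
# Hypothetical sizes for illustration only.
memory_dim = 80    # mel channels per frame
r = 7              # frames predicted per decoder step
memory_size = 5    # queued frames; a value <= 0 disables the queue

use_memory_queue = memory_size > 0
memory_size = memory_size if memory_size > 0 else r

# Prenet input width: the whole queue when queuing is on, one frame otherwise.
prenet_in = memory_dim * memory_size if use_memory_queue else memory_dim
print(prenet_in)   # 400 here; 80 if memory_size were <= 0
```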
```diff
@@ -311,21 +313,9 @@ class Decoder(nn.Module):
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
-        self.proj_to_mel = nn.Linear(256, memory_dim * r)
+        self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
         # learn init values instead of zero init.
-        self.attention_rnn_init = nn.Embedding(1, 256)
-        self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
-        self.decoder_rnn_inits = nn.Embedding(2, 256)
-        self.stopnet = StopNet(256 + memory_dim * r)
-        # self.init_layers()
-
-    def init_layers(self):
-        torch.nn.init.xavier_uniform_(
-            self.project_to_decoder_in.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
-        torch.nn.init.xavier_uniform_(
-            self.proj_to_mel.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
+        self.stopnet = StopNet(256 + memory_dim * self.r_init)
 
     def _set_r(self, new_r):
         self.r = new_r
```
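The learned initial-state embeddings (`attention_rnn_init`, `memory_init`, `decoder_rnn_inits`) and the `init_layers` helper are dropped, and `proj_to_mel` and the `StopNet` are now sized by `self.r_init`, the reduction factor the model was constructed with, so later reductions of `self.r` via `_set_r` never change parameter shapes. A sketch of why this works, with hypothetical sizes; the final slice mirrors the `output[:, : self.r * self.memory_dim]` line in `decode`:

```python
import torch
import torch.nn as nn

# Hypothetical sizes; r_init is the reduction factor at construction time.
memory_dim, r_init = 80, 7
proj_to_mel = nn.Linear(256, memory_dim * r_init)

rnn_state = torch.randn(4, 256)        # a batch of decoder RNN states
output = proj_to_mel(rnn_state)        # always memory_dim * r_init wide

r = 2                                  # r reduced later in training
output = output[:, : r * memory_dim]   # keep only the active frames
print(output.shape)                    # torch.Size([4, 160])
```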
```diff
@@ -350,13 +340,14 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         T = inputs.size(1)
         # go frame as zeros matrix
-        self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
+        if self.use_memory_queue:
+            self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
+        else:
+            self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
         # decoder states
-        self.attention_rnn_hidden = self.attention_rnn_init(
-            inputs.data.new_zeros(B).long())
+        self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
         self.decoder_rnn_hiddens = [
-            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
+            torch.zeros(B, 256, device=inputs.device)
             for idx in range(len(self.decoder_rnns))
         ]
         self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
```
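The state reset now uses plain zero tensors in place of the removed embedding lookups, and the width of the initial `memory_input` depends on whether queuing is active. A self-contained sketch with assumed sizes:

```python
import torch

# Assumed sizes; mirrors the new zero-initialised state reset.
B, memory_dim, memory_size = 4, 80, 5
use_memory_queue = memory_size > 0
device = torch.device("cpu")

if use_memory_queue:
    # the queue holds memory_size flattened frames
    memory_input = torch.zeros(B, memory_dim * memory_size, device=device)
else:
    # no queue: only a single frame is fed back
    memory_input = torch.zeros(B, memory_dim, device=device)

attention_rnn_hidden = torch.zeros(B, 256, device=device)
decoder_rnn_hiddens = [torch.zeros(B, 256, device=device) for _ in range(2)]
```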
```diff
@@ -407,11 +398,20 @@ class Decoder(nn.Module):
             output = output[:, : self.r * self.memory_dim]
         return output, stop_token, self.attention_layer.attention_weights
 
-    def _update_memory_queue(self, new_memory):
+    def _update_memory_input(self, new_memory):
+        if self.use_memory_queue:
+            if self.memory_size > self.r:
+                # memory queue size is larger than number of frames per decoder iter
                 self.memory_input = torch.cat([
                     self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
                 ],
                                               dim=-1)
+            else:
+                # memory queue size smaller than number of frames per decoder iter
+                self.memory_input = new_memory[:, (self.r - self.memory_size) * self.memory_dim:]
+        else:
+            # use only the last frame prediction
+            self.memory_input = new_memory[:, (self.r - 1) * self.memory_dim:]
 
     def forward(self, inputs, memory, mask):
         """
```
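`_update_memory_queue` is renamed to `_update_memory_input` and gains three cases: a FIFO update when the queue holds more frames than one decoder step emits, a truncation when it holds fewer, and plain last-frame feedback when queuing is off. A standalone sketch of the same logic (the free function and the sizes below are illustrative, not the repository's API):

```python
import torch

def update_memory_input(memory_input, new_memory, r, memory_size,
                        memory_dim, use_memory_queue):
    """Sketch of the renamed _update_memory_input logic.

    new_memory holds the last r decoded frames, flattened to
    (B, r * memory_dim) with the newest frame at the end.
    """
    if use_memory_queue:
        if memory_size > r:
            # Queue holds more frames than one step produces: drop the
            # r oldest frames, append the r new ones (FIFO).
            return torch.cat(
                [memory_input[:, r * memory_dim:], new_memory], dim=-1)
        # Queue no larger than one step's output: keep only the newest
        # memory_size frames of this step.
        return new_memory[:, (r - memory_size) * memory_dim:]
    # Queuing disabled: feed back only the last predicted frame.
    return new_memory[:, (r - 1) * memory_dim:]

# Tiny check of the FIFO case (hypothetical sizes).
B, r, memory_size, memory_dim = 2, 2, 5, 3
q = torch.zeros(B, memory_size * memory_dim)
new = torch.ones(B, r * memory_dim)
q = update_memory_input(q, new, r, memory_size, memory_dim, True)
print(q.shape)  # torch.Size([2, 15]); the last 6 values are the new frames
```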
```diff
@@ -437,7 +437,7 @@ class Decoder(nn.Module):
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
-                self._update_memory_queue(new_memory)
+                self._update_memory_input(new_memory)
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
```
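In the teacher-forced `forward` pass, the queue is refreshed from `memory[t - 1]`, the previous step's ground-truth frames rather than the model's own output. A stripped-down sketch of the loop shape, with `decode()` replaced by a stand-in:

```python
import torch

# Assumed shapes: memory is ground truth grouped into decoder steps,
# (T_decoder, B, r * memory_dim).
T_decoder, B, r, memory_dim = 4, 2, 2, 3
memory = torch.randn(T_decoder, B, r * memory_dim)

t, outputs = 0, []
while len(outputs) < memory.size(0):
    if t > 0:
        new_memory = memory[t - 1]   # ground truth, not the prediction
        # ... _update_memory_input(new_memory) would run here ...
    outputs.append(torch.zeros(B, r * memory_dim))  # stand-in for decode()
    t += 1
```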
```diff
@@ -464,7 +464,7 @@ class Decoder(nn.Module):
         while True:
             if t > 0:
                 new_memory = outputs[-1]
-                self._update_memory_queue(new_memory)
+                self._update_memory_input(new_memory)
             output, stop_token, attention = self.decode(inputs, None)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [output]
```
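At inference the same update is driven by `outputs[-1]`, the model's own previous prediction. A matching sketch; the stop logic here is a stand-in, whereas the real loop thresholds the sigmoid of the `StopNet` output and also respects `max_decoder_steps`:

```python
import torch

# Stand-in autoregressive loop with assumed sizes.
B, r, memory_dim, max_decoder_steps = 2, 2, 3, 500
t, outputs = 0, []
while True:
    if t > 0:
        new_memory = outputs[-1]     # previous prediction is fed back
        # ... _update_memory_input(new_memory) would run here ...
    output = torch.zeros(B, r * memory_dim)   # stand-in for decode()
    stop_token = 1.0 if t >= 3 else 0.0       # stand-in stop signal
    outputs.append(output)
    t += 1
    if stop_token > 0.5 or t > max_decoder_steps:
        break
```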