mirror of https://github.com/coqui-ai/TTS.git
memoru queueing
This commit is contained in:
parent
ebc166b54f
commit
90f0cd640b
|
@ -298,16 +298,18 @@ class Decoder(nn.Module):
|
|||
in_features (int): input vector (encoder output) sample size.
|
||||
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||
r (int): number of outputs per time step.
|
||||
memory_size (int): size of the past window. if <= 0 memory_size = r
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, memory_dim, r, attn_windowing):
|
||||
def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
|
||||
super(Decoder, self).__init__()
|
||||
self.r = r
|
||||
self.in_features = in_features
|
||||
self.max_decoder_steps = 500
|
||||
self.memory_size = memory_size if memory_size > 0 else r
|
||||
self.memory_dim = memory_dim
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
self.prenet = Prenet(memory_dim * memory_dim * self.memory_size, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = AttentionRNNCell(
|
||||
out_dim=128,
|
||||
|
@ -325,7 +327,7 @@ class Decoder(nn.Module):
|
|||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
# learn init values instead of zero init.
|
||||
self.attention_rnn_init = nn.Embedding(1, 256)
|
||||
self.memory_init = nn.Embedding(1, r * memory_dim)
|
||||
self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
|
||||
self.decoder_rnn_inits = nn.Embedding(2, 256)
|
||||
self.stopnet = StopNet(256 + memory_dim * r)
|
||||
# self.init_layers()
|
||||
|
@ -359,6 +361,7 @@ class Decoder(nn.Module):
|
|||
T = inputs.size(1)
|
||||
# go frame as zeros matrix
|
||||
initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
|
||||
|
||||
# decoder states
|
||||
attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
|
||||
decoder_rnn_hiddens = [
|
||||
|
@ -404,9 +407,14 @@ class Decoder(nn.Module):
|
|||
while True:
|
||||
if t > 0:
|
||||
if memory is None:
|
||||
memory_input = outputs[-1]
|
||||
new_memory = outputs[-1]
|
||||
else:
|
||||
memory_input = memory[t - 1]
|
||||
new_memory = memory[t - 1]
|
||||
# Queuing if memory size defined else use previous prediction only.
|
||||
if self.memory_size > 0:
|
||||
memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
|
||||
else:
|
||||
memory_input = new_memory
|
||||
# Prenet
|
||||
processed_memory = self.prenet(memory_input)
|
||||
# Attention RNN
|
||||
|
|
Loading…
Reference in New Issue