mirror of https://github.com/coqui-ai/TTS.git
memoru queueing
This commit is contained in:
parent
ebc166b54f
commit
90f0cd640b
|
@ -298,16 +298,18 @@ class Decoder(nn.Module):
|
||||||
in_features (int): input vector (encoder output) sample size.
|
in_features (int): input vector (encoder output) sample size.
|
||||||
memory_dim (int): memory vector (prev. time-step output) sample size.
|
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||||
r (int): number of outputs per time step.
|
r (int): number of outputs per time step.
|
||||||
|
memory_size (int): size of the past window. if <= 0 memory_size = r
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, memory_dim, r, attn_windowing):
|
def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
self.r = r
|
self.r = r
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.max_decoder_steps = 500
|
self.max_decoder_steps = 500
|
||||||
|
self.memory_size = memory_size if memory_size > 0 else r
|
||||||
self.memory_dim = memory_dim
|
self.memory_dim = memory_dim
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * memory_dim * self.memory_size, out_features=[256, 128])
|
||||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||||
self.attention_rnn = AttentionRNNCell(
|
self.attention_rnn = AttentionRNNCell(
|
||||||
out_dim=128,
|
out_dim=128,
|
||||||
|
@ -325,7 +327,7 @@ class Decoder(nn.Module):
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
# learn init values instead of zero init.
|
# learn init values instead of zero init.
|
||||||
self.attention_rnn_init = nn.Embedding(1, 256)
|
self.attention_rnn_init = nn.Embedding(1, 256)
|
||||||
self.memory_init = nn.Embedding(1, r * memory_dim)
|
self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
|
||||||
self.decoder_rnn_inits = nn.Embedding(2, 256)
|
self.decoder_rnn_inits = nn.Embedding(2, 256)
|
||||||
self.stopnet = StopNet(256 + memory_dim * r)
|
self.stopnet = StopNet(256 + memory_dim * r)
|
||||||
# self.init_layers()
|
# self.init_layers()
|
||||||
|
@ -359,6 +361,7 @@ class Decoder(nn.Module):
|
||||||
T = inputs.size(1)
|
T = inputs.size(1)
|
||||||
# go frame as zeros matrix
|
# go frame as zeros matrix
|
||||||
initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
|
initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
|
||||||
|
|
||||||
# decoder states
|
# decoder states
|
||||||
attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
|
attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
|
||||||
decoder_rnn_hiddens = [
|
decoder_rnn_hiddens = [
|
||||||
|
@ -404,9 +407,14 @@ class Decoder(nn.Module):
|
||||||
while True:
|
while True:
|
||||||
if t > 0:
|
if t > 0:
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory_input = outputs[-1]
|
new_memory = outputs[-1]
|
||||||
else:
|
else:
|
||||||
memory_input = memory[t - 1]
|
new_memory = memory[t - 1]
|
||||||
|
# Queuing if memory size defined else use previous prediction only.
|
||||||
|
if self.memory_size > 0:
|
||||||
|
memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
|
||||||
|
else:
|
||||||
|
memory_input = new_memory
|
||||||
# Prenet
|
# Prenet
|
||||||
processed_memory = self.prenet(memory_input)
|
processed_memory = self.prenet(memory_input)
|
||||||
# Attention RNN
|
# Attention RNN
|
||||||
|
|
Loading…
Reference in New Issue