mirror of https://github.com/coqui-ai/TTS.git
remove intermediate tensor asap
This commit is contained in:
parent
b6f559d315
commit
4826e7db9c
|
@ -86,15 +86,15 @@ class LocationSensitiveAttention(nn.Module):
|
||||||
if query.dim() == 2:
|
if query.dim() == 2:
|
||||||
# insert time-axis for broadcasting
|
# insert time-axis for broadcasting
|
||||||
query = query.unsqueeze(1)
|
query = query.unsqueeze(1)
|
||||||
loc_conv = self.loc_conv(loc)
|
processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2))
|
||||||
loc_conv = loc_conv.transpose(1, 2)
|
|
||||||
processed_loc = self.loc_linear(loc_conv)
|
|
||||||
processed_query = self.query_layer(query)
|
processed_query = self.query_layer(query)
|
||||||
# cache annots
|
# cache annots
|
||||||
if self.processed_annots is None:
|
if self.processed_annots is None:
|
||||||
self.processed_annots = self.annot_layer(annot)
|
self.processed_annots = self.annot_layer(annot)
|
||||||
alignment = self.v(
|
alignment = self.v(
|
||||||
torch.tanh(processed_query + self.processed_annots + processed_loc))
|
torch.tanh(processed_query + self.processed_annots + processed_loc))
|
||||||
|
del processed_loc
|
||||||
|
del processed_query
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
return alignment.squeeze(-1)
|
return alignment.squeeze(-1)
|
||||||
|
|
||||||
|
@ -138,11 +138,9 @@ class AttentionRNNCell(nn.Module):
|
||||||
"""
|
"""
|
||||||
if t == 0:
|
if t == 0:
|
||||||
self.alignment_model.reset()
|
self.alignment_model.reset()
|
||||||
# Concat input query and previous context context
|
|
||||||
rnn_input = torch.cat((memory, context), -1)
|
|
||||||
# Feed it to RNN
|
# Feed it to RNN
|
||||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||||
rnn_output = self.rnn_cell(rnn_input, rnn_state)
|
rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
|
||||||
# Alignment
|
# Alignment
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
# e_{ij} = a(s_{i-1}, h_j)
|
# e_{ij} = a(s_{i-1}, h_j)
|
||||||
|
|
|
@ -403,6 +403,7 @@ class Decoder(nn.Module):
|
||||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||||
inputs, attention_cat, mask, t)
|
inputs, attention_cat, mask, t)
|
||||||
|
del attention_cat
|
||||||
attention_cum += attention
|
attention_cum += attention
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
|
@ -414,15 +415,19 @@ class Decoder(nn.Module):
|
||||||
# Residual connectinon
|
# Residual connectinon
|
||||||
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
||||||
decoder_output = decoder_input
|
decoder_output = decoder_input
|
||||||
|
del decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(decoder_output)
|
output = self.proj_to_mel(decoder_output)
|
||||||
output = torch.sigmoid(output)
|
output = torch.sigmoid(output)
|
||||||
# predict stop token
|
# predict stop token
|
||||||
stopnet_input = torch.cat([decoder_input, output], -1)
|
stopnet_input = torch.cat([decoder_output, output], -1)
|
||||||
|
del decoder_output
|
||||||
stop_token = self.stopnet(stopnet_input)
|
stop_token = self.stopnet(stopnet_input)
|
||||||
|
del stopnet_input
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
attentions += [attention]
|
attentions += [attention]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
|
del output
|
||||||
t += 1
|
t += 1
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
|
|
Loading…
Reference in New Issue