mirror of https://github.com/coqui-ai/TTS.git
remove intermediate tensor asap
This commit is contained in:
parent
b6f559d315
commit
4826e7db9c
|
@ -86,15 +86,15 @@ class LocationSensitiveAttention(nn.Module):
|
|||
if query.dim() == 2:
|
||||
# insert time-axis for broadcasting
|
||||
query = query.unsqueeze(1)
|
||||
loc_conv = self.loc_conv(loc)
|
||||
loc_conv = loc_conv.transpose(1, 2)
|
||||
processed_loc = self.loc_linear(loc_conv)
|
||||
processed_loc = self.loc_linear(self.loc_conv(loc).transpose(1, 2))
|
||||
processed_query = self.query_layer(query)
|
||||
# cache annots
|
||||
if self.processed_annots is None:
|
||||
self.processed_annots = self.annot_layer(annot)
|
||||
alignment = self.v(
|
||||
torch.tanh(processed_query + self.processed_annots + processed_loc))
|
||||
del processed_loc
|
||||
del processed_query
|
||||
# (batch, max_time)
|
||||
return alignment.squeeze(-1)
|
||||
|
||||
|
@ -138,11 +138,9 @@ class AttentionRNNCell(nn.Module):
|
|||
"""
|
||||
if t == 0:
|
||||
self.alignment_model.reset()
|
||||
# Concat input query and previous context context
|
||||
rnn_input = torch.cat((memory, context), -1)
|
||||
# Feed it to RNN
|
||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||
rnn_output = self.rnn_cell(rnn_input, rnn_state)
|
||||
rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
|
||||
# Alignment
|
||||
# (batch, max_time)
|
||||
# e_{ij} = a(s_{i-1}, h_j)
|
||||
|
|
|
@ -403,6 +403,7 @@ class Decoder(nn.Module):
|
|||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||
inputs, attention_cat, mask, t)
|
||||
del attention_cat
|
||||
attention_cum += attention
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
|
@ -414,15 +415,19 @@ class Decoder(nn.Module):
|
|||
# Residual connectinon
|
||||
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
||||
decoder_output = decoder_input
|
||||
del decoder_input
|
||||
# predict mel vectors from decoder vectors
|
||||
output = self.proj_to_mel(decoder_output)
|
||||
output = torch.sigmoid(output)
|
||||
# predict stop token
|
||||
stopnet_input = torch.cat([decoder_input, output], -1)
|
||||
stopnet_input = torch.cat([decoder_output, output], -1)
|
||||
del decoder_output
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
del stopnet_input
|
||||
outputs += [output]
|
||||
attentions += [attention]
|
||||
stop_tokens += [stop_token]
|
||||
del output
|
||||
t += 1
|
||||
if memory is not None:
|
||||
if t >= T_decoder:
|
||||
|
|
Loading…
Reference in New Issue