mirror of https://github.com/coqui-ai/TTS.git
linter fix
parent 77f5fd0584
commit ea32f2368d
@@ -162,15 +162,16 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         # T = inputs.size(1)
         if not keep_states:
-            self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
-            self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                                         self.query_dim)
-            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                              self.decoder_rnn_dim)
-            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                            self.decoder_rnn_dim)
-            self.context = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                       self.encoder_embedding_dim)
+            self.query = torch.zeros(1, device=inputs.device).repeat(
+                B, self.query_dim)
+            self.attention_rnn_cell_state = torch.zeros(
+                1, device=inputs.device).repeat(B, self.query_dim)
+            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.context = torch.zeros(1, device=inputs.device).repeat(
+                B, self.encoder_embedding_dim)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
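The hunk above is purely a line-length fix: each decoder state is still built with torch.zeros(1, device=...).repeat(B, dim), only the calls are now wrapped across two lines. A minimal sketch of that initialization pattern, using made-up sizes in place of the Decoder's real attributes (self.query_dim and friends):

import torch

# Stand-in values; in the Decoder these come from self.query_dim,
# self.decoder_rnn_dim, self.encoder_embedding_dim and inputs.device.
B, query_dim = 4, 256
device = "cpu"

# zeros(1).repeat(B, dim) expands the single zero into a (B, dim) tensor on
# the chosen device, equivalent to torch.zeros(B, dim, device=device).
query = torch.zeros(1, device=device).repeat(B, query_dim)
assert query.shape == (4, 256)
assert torch.equal(query, torch.zeros(B, query_dim, device=device))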
@@ -277,7 +278,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
 
@@ -317,7 +318,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
 
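The elif -> if change repeated in the last two hunks is the substance of the linter fix: because the preceding branch ends in break, the elif adds nothing, and linters such as pylint flag the pattern (no-else-break). A small self-contained sketch of the same refactor, with illustrative names rather than the Decoder's real loop:

max_decoder_steps = 5
outputs = []
for t in range(100):
    stop = t > 50
    if stop:
        break
    # Formerly `elif ...`: the branch above already leaves the loop, so a
    # plain `if` is behaviourally identical and keeps the linter quiet.
    if len(outputs) == max_decoder_steps:
        print(" | > Decoder stopped with 'max_decoder_steps")
        break
    outputs.append(t)
print(len(outputs))  # 5 -> stopped by the max-steps guard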