mirror of https://github.com/coqui-ai/TTS.git
Fix bug in Graves Attn
On my machine, in GravesAttention the variable self.J (self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5) is created as a LongTensor, but it must be a float tensor. So I get the following error:

Traceback (most recent call last):
  File "train.py", line 704, in <module>
    main(args)
  File "train.py", line 619, in main
    global_step, epoch)
  File "train.py", line 170, in train
    text_input, text_lengths, mel_input, speaker_embeddings=speaker_embeddings)
  File "/home/edresson/anaconda3/envs/TTS2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/edresson/DD/TTS/voice-clonning/TTS/tts_namespace/TTS/models/tacotron.py", line 121, in forward
    self.speaker_embeddings_projected)
  File "/home/edresson/anaconda3/envs/TTS2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/edresson/DD/TTS/voice-clonning/TTS/tts_namespace/TTS/layers/tacotron.py", line 435, in forward
    output, stop_token, attention = self.decode(inputs, mask)
  File "/mnt/edresson/DD/TTS/voice-clonning/TTS/tts_namespace/TTS/layers/tacotron.py", line 367, in decode
    self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
  File "/home/edresson/anaconda3/envs/TTS2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/edresson/DD/TTS/voice-clonning/TTS/tts_namespace/TTS/layers/common_layers.py", line 180, in forward
    phi_t = g_t.unsqueeze(-1) * (1.0 / (1.0 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
RuntimeError: expected type torch.cuda.FloatTensor but got torch.cuda.LongTensor

In addition, the + 0.5 offset is silently dropped when the tensor is a LongTensor. Test:

>>> torch.arange(0, 10)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> torch.arange(0, 10) + 0.5
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> torch.arange(0, 10.0) + 0.5
tensor([0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, 7.5000, 8.5000, 9.5000])

To resolve this, I forced the arange end value to be a float:

self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
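For reference, a minimal sketch (not part of the commit; only torch is assumed, and T is a hypothetical stand-in for inputs.shape[1]) showing that the type of the end value determines the dtype torch.arange returns, and that the fixed line keeps the half-step offset:

import torch

T = 10  # hypothetical number of encoder timesteps, i.e. inputs.shape[1]

# An integer end value makes arange return a LongTensor (torch.int64) ...
print(torch.arange(0, T + 2).dtype)    # torch.int64
# ... while a float end value yields a FloatTensor (torch.float32).
print(torch.arange(0, T + 2.0).dtype)  # torch.float32

# With the fixed line, the half-step offset is preserved:
J = torch.arange(0, T + 2.0) + 0.5
print(J[:3])  # tensor([0.5000, 1.5000, 2.5000])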
parent f7b1cad9ee
commit cce13ee245
@@ -138,7 +138,7 @@ class GravesAttention(nn.Module):
     def init_states(self, inputs):
         if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
+            self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
         self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
         self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
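An equivalent, slightly more explicit variant (a suggestion, not what this commit applies) would be to request the dtype and device directly from torch.arange instead of relying on the type of the end value:

import torch

def init_J(inputs):
    # Hypothetical helper mirroring the patched line in init_states().
    return torch.arange(0, inputs.shape[1] + 2,
                        dtype=torch.float32, device=inputs.device) + 0.5

print(init_J(torch.zeros(4, 10))[:3])  # tensor([0.5000, 1.5000, 2.5000])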