mirror of https://github.com/coqui-ai/TTS.git
guided attn #6
This commit is contained in:
parent
aea26bf55a
commit
60a0721bf3
2
train.py
2
train.py
|
@ -115,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
M[n, t] = val
|
||||
e_x = np.exp(M - np.max(M))
|
||||
M = e_x / e_x.sum(axis=0) # only difference
|
||||
M = Variable(torch.FloatTensor(M))
|
||||
M = Variable(torch.FloatTensor(M)).cuda()
|
||||
M = torch.stack([M]*32)
|
||||
|
||||
# forward pass
|
||||
|
|
Loading…
Reference in New Issue