mirror of https://github.com/coqui-ai/TTS.git
guided attn #10
This commit is contained in:
parent
ee00b3db5c
commit
62b5ee8f38
6
train.py
6
train.py
|
@ -106,10 +106,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
|
|
||||||
# create attention mask
|
# create attention mask
|
||||||
# TODO: vectorize
|
# TODO: vectorize
|
||||||
print(text_input_var.shape)
|
|
||||||
print(mel_spec_var.shape)
|
|
||||||
N = text_input_var.shape[1]
|
N = text_input_var.shape[1]
|
||||||
T = mel_spec_var.shape[1]
|
T = mel_spec_var.shape[1] / c.r
|
||||||
M = np.zeros([N, T])
|
M = np.zeros([N, T])
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
for n in range(N):
|
for n in range(N):
|
||||||
|
@ -117,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
M[n, t] = val
|
M[n, t] = val
|
||||||
e_x = np.exp(M - np.max(M))
|
e_x = np.exp(M - np.max(M))
|
||||||
M = e_x / e_x.sum(axis=0) # only difference
|
M = e_x / e_x.sum(axis=0) # only difference
|
||||||
M = Variable(torch.FloatTensor(M)).cuda()
|
M = Variable(torch.FloatTensor(M).t()).cuda()
|
||||||
M = torch.stack([M]*32)
|
M = torch.stack([M]*32)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
|
Loading…
Reference in New Issue