mirror of https://github.com/coqui-ai/TTS.git
guided attn #12
This commit is contained in:
parent
e87d7d8d26
commit
3161a16101
2
train.py
2
train.py
|
@ -130,7 +130,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
mel_lengths_var)
|
||||
print(M.shape)
|
||||
print(alignments.shape)
|
||||
attention_loss = criterion(M, alignments, mel_lengths_var)
|
||||
attention_loss = criterion(alignments, M, mel_lengths_var)
|
||||
loss = mel_loss + linear_loss + 0.2 * attention_loss
|
||||
|
||||
# backpass and check the grad norm
|
||||
|
|
Loading…
Reference in New Issue