mirror of https://github.com/coqui-ai/TTS.git
Enforce monotonic attention for forward attention at eval time
This commit is contained in:
parent
ba492f43be
commit
59ba37904d
|
@ -219,16 +219,18 @@ class Attention(nn.Module):
|
||||||
# forward attention
|
# forward attention
|
||||||
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
|
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
|
||||||
(1, 0, 0, 0)).to(inputs.device)
|
(1, 0, 0, 0)).to(inputs.device)
|
||||||
# force incremental alignment
|
# compute transition potentials
|
||||||
if not self.training:
|
|
||||||
val, n = prev_alpha.max(1)
|
|
||||||
if alignment.shape[0] == 1:
|
|
||||||
alignment[:, n+2:] = 0
|
|
||||||
else:
|
|
||||||
for b in range(alignment.shape[0]):
|
|
||||||
alignment[b, n[b]+2:]
|
|
||||||
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
|
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
|
||||||
self.u * prev_alpha) + 1e-8) * alignment
|
self.u * prev_alpha) + 1e-8) * alignment
|
||||||
|
# force incremental alignment - TODO: make configurable
|
||||||
|
if not self.training and alignment.shape[0] == 1:
|
||||||
|
_, n = prev_alpha.max(1)
|
||||||
|
val, n2 = alpha.max(1)
|
||||||
|
for b in range(alignment.shape[0]):
|
||||||
|
alpha[b, n+2:] = 0
|
||||||
|
alpha[b, :(n - 1)] = 0 # ignore all previous states to prevent repetition.
|
||||||
|
alpha[b, (n - 2)] = 0.01 * val # smoothing factor for the prev step
|
||||||
|
# compute attention weights
|
||||||
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
||||||
# compute context
|
# compute context
|
||||||
context = torch.bmm(self.alpha.unsqueeze(1), inputs)
|
context = torch.bmm(self.alpha.unsqueeze(1), inputs)
|
||||||
|
|
Loading…
Reference in New Issue