mirror of https://github.com/coqui-ai/TTS.git
Force alignment of forward attention
This commit is contained in:
parent
f08be38b9f
commit
3a4a3e571a
|
@ -219,6 +219,10 @@ class Attention(nn.Module):
|
||||||
# forward attention
|
# forward attention
|
||||||
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
|
prev_alpha = F.pad(self.alpha[:, :-1].clone(),
|
||||||
(1, 0, 0, 0)).to(inputs.device)
|
(1, 0, 0, 0)).to(inputs.device)
|
||||||
|
# force incremental alignment
|
||||||
|
if not self.training:
|
||||||
|
val, n = prev_alpha.max(1)
|
||||||
|
alignment[:, n+2 :] = 0
|
||||||
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
|
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
|
||||||
self.u * prev_alpha) + 1e-8) * alignment
|
self.u * prev_alpha) + 1e-8) * alignment
|
||||||
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
||||||
|
@ -500,7 +504,7 @@ class Decoder(nn.Module):
|
||||||
stop_flags[2] = t > inputs.shape[1] * 2
|
stop_flags[2] = t > inputs.shape[1] * 2
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
stop_count += 1
|
stop_count += 1
|
||||||
if stop_count > 5:
|
if stop_count > 10:
|
||||||
break
|
break
|
||||||
elif len(outputs) == self.max_decoder_steps:
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
|
|
Loading…
Reference in New Issue