diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index daea2bd8..4fe6c5b8 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -219,16 +219,18 @@ class Attention(nn.Module):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(),
                            (1, 0, 0, 0)).to(inputs.device)
-        # force incremental alignment
-        if not self.training:
-            val, n = prev_alpha.max(1)
-            if alignment.shape[0] == 1:
-                alignment[:, n+2:] = 0
-            else:
-                for b in range(alignment.shape[0]):
-                    alignment[b, n[b]+2:]
+        # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
+        # force incremental alignment - TODO: make configurable
+        if not self.training and alignment.shape[0] == 1:
+            _, n = prev_alpha.max(1)
+            val, n2 = alpha.max(1)
+            for b in range(alignment.shape[0]):
+                alpha[b, n+2:] = 0
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
+        # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)