diff --git a/train.py b/train.py index 2927dd88..051fb8ad 100644 --- a/train.py +++ b/train.py @@ -111,7 +111,7 @@ def train(model, criterion, data_loader, optimizer, epoch): # create attention mask N = text_input_var.shape[1] T = mel_spec_var.shape[1] // c.r - M = create_attn_mask(N, T, g) + M = create_attn_mask(N, T, 0.03) # forward pass mel_output, linear_output, alignments =\