Attn masking

Eren G 2018-07-13 14:50:55 +02:00
parent 9f52833151
commit dac8fdffa9
6 changed files with 199 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
from utils.generic_utils import sequence_mask
class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
def forward(self, memory, context, rnn_state, annotations,
attention_vec, mask=None, annotations_lengths=None):
def forward(self, memory, context, rnn_state, annots,
atten, annot_lens=None):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
- context: (batch, dim)
- rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim)
- atten: (batch, max_time)
- annot_lens: (batch,)
"""
# Concatenate the input query and the previous context vector
rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
if self.align_model == 'b':
alignment = self.alignment_model(annotations, rnn_output)
alignment = self.alignment_model(annots, rnn_output)
else:
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
# TODO: needs recheck.
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
alignment = self.alignment_model(annots, rnn_output, atten)
if annot_lens is not None:
mask = sequence_mask(annot_lens)
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize alignment scores into attention weights
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context = torch.bmm(alignment.unsqueeze(1), annotations)
context = torch.bmm(alignment.unsqueeze(1), annots)
context = context.squeeze(1)
return rnn_output, context, alignment
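
The masking added above is the heart of this commit: alignment scores over padded annotation frames are filled with -inf before the softmax, so those frames receive exactly zero attention weight. A minimal self-contained sketch of the idea (function name and shapes assumed for illustration, not the repo's exact code):

import torch
from torch.nn import functional as F

def masked_softmax(scores, annot_lens):
    # scores: (batch, max_time) raw alignment energies
    # annot_lens: (batch,) number of valid annotation frames per sample
    max_time = scores.size(1)
    # True where position j < annot_lens[i]
    mask = (torch.arange(max_time, device=scores.device).unsqueeze(0)
            < annot_lens.unsqueeze(1))
    # padded positions get -inf, so softmax assigns them zero weight
    scores = scores.masked_fill(~mask, -float("inf"))
    return F.softmax(scores, dim=-1)

# second sample has only 2 valid frames out of 4
weights = masked_softmax(torch.randn(2, 4), torch.tensor([4, 2]))
assert weights[1, 2:].sum() == 0

Masking before the softmax, rather than zeroing weights afterwards, keeps the remaining weights normalized to one.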

View File

@@ -1,26 +1,53 @@
import torch
from torch.nn import functional
from torch import nn
from utils.generic_utils import sequence_mask
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def _sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand
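
For intuition, the gist-derived helper above (now dropped in favor of the shared sequence_mask import from utils.generic_utils) turns a vector of lengths into a per-timestep validity mask:

import torch
lengths = torch.tensor([3, 1])
mask = sequence_mask(lengths, max_len=4)
# each row flags the valid steps of one sample:
# [[1, 1, 1, 0],
#  [1, 0, 0, 0]]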
class L2LossMasked(nn.Module):
class L1LossMasked(nn.Module):
def __init__(self):
super(L2LossMasked, self).__init__()
super(L1LossMasked, self).__init__()
def forward(self, input, target, length):
"""
Args:
input: A Variable containing a FloatTensor of size
(batch, max_len, dim) which contains the
predicted values for each step.
target: A Variable containing a FloatTensor of size
(batch, max_len, dim) which contains the ground-truth
values for each corresponding step.
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each sequence in the batch.
Returns:
loss: An average loss value masked by the length.
"""
input = input.contiguous()
target = target.contiguous()
# input_flat: (batch * max_len, dim)
input = input.view(-1, input.shape[-1])
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss
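
The forward pass above follows a flatten-mask-average pattern; here is a minimal sketch of the same computation, assuming a newer PyTorch where reduction='none' replaces the deprecated size_average/reduce flags (the MSELossMasked variant below is identical with F.mse_loss swapped in):

import torch
from torch.nn import functional as F

def masked_l1(pred, target, lengths):
    # pred, target: (batch, max_len, dim); lengths: (batch,)
    losses = F.l1_loss(pred, target, reduction="none")
    # (batch, max_len, 1) mask of valid timesteps
    mask = (torch.arange(target.size(1), device=lengths.device).unsqueeze(0)
            < lengths.unsqueeze(1)).unsqueeze(2)
    losses = losses * mask.float()
    # average over valid frames and feature dimensions only
    return losses.sum() / (lengths.float().sum() * target.size(2))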
class MSELossMasked(nn.Module):
def __init__(self):
super(MSELossMasked, self).__init__()
def forward(self, input, target, length):
"""
@@ -50,7 +77,7 @@ class L2LossMasked(nn.Module):
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length,
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))

View File

@@ -213,7 +213,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None):
def forward(self, inputs, memory=None, input_lens=None):
"""
Decoder forward step.
@@ -225,6 +225,7 @@
memory (None): Decoder memory (used for autoregression). If None
(at eval time), the previous decoder output is used as the next input.
input_lens (None): Time length of each input in batch.
Shapes:
- inputs: batch x time x encoder_out_dim
@@ -273,7 +274,8 @@
# attention_cum.unsqueeze(1)),
# dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention, input_lens)
# attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(

View File

@@ -21,14 +21,14 @@ class Tacotron(nn.Module):
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
def forward(self, characters, mel_specs=None):
def forward(self, characters, mel_specs=None, text_lens=None):
B = characters.size(0)
inputs = self.embedding(characters)
# batch x time x dim
encoder_outputs = self.encoder(inputs)
# batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs)
encoder_outputs, mel_specs, text_lens)
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
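
With the decoder change above, text_lens now flows Tacotron.forward -> Decoder.forward -> AttentionRNNCell.forward, where it masks the alignment scores over padded character positions. A hypothetical invocation, with tensor sizes invented for illustration and the model assumed to be constructed elsewhere:

import torch

characters = torch.randint(0, 60, (2, 50))  # (batch, max_chars)
text_lens = torch.tensor([50, 32])          # second sample is padded
mel_specs = torch.randn(2, 120, 80)         # (batch, frames, num_mels)
mel_out, linear_out, alignments, stop_tokens = model(
    characters, mel_specs, text_lens)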

View File

@@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay,
count_parameters, check_update, get_commit_hash)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
@@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda()
text_lengths = text_lengths.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
@@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# forward pass
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input)
model.forward(text_input, mel_input, text_lengths)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets)

View File

@@ -12,7 +12,7 @@ class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
griffin_lim_iters=None):
min_mel_freq, max_mel_freq, griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
@@ -22,6 +22,8 @@ class AudioProcessor(object):
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.min_mel_freq = min_mel_freq
self.max_mel_freq = max_mel_freq
self.griffin_lim_iters = griffin_lim_iters
def save_wav(self, wav, path):
@@ -36,7 +38,8 @@ class AudioProcessor(object):
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
fmin=self.min_mel_freq, fmax=self.max_mel_freq)
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@@ -63,7 +66,8 @@ class AudioProcessor(object):
return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
# D = self._stft(self.apply_preemphasis(y))
D = self._stft(y)
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S)
@@ -72,7 +76,8 @@ class AudioProcessor(object):
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
# return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
# def _griffin_lim(self, S):
# '''librosa implementation of Griffin-Lim
@@ -110,7 +115,7 @@ class AudioProcessor(object):
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)
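
Constructing the updated AudioProcessor now requires the two new mel-frequency bounds, which _build_mel_basis forwards to librosa.filters.mel as fmin/fmax. A hypothetical call, with values chosen for illustration rather than taken from the repo's config:

ap = AudioProcessor(
    sample_rate=22050, num_mels=80, min_level_db=-100,
    frame_shift_ms=12.5, frame_length_ms=50, preemphasis=0.98,
    ref_level_db=20, num_freq=1025, power=1.5,
    min_mel_freq=0, max_mel_freq=8000, griffin_lim_iters=60)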