Attn masking

Eren G 2018-07-13 14:50:55 +02:00
parent 9f52833151
commit dac8fdffa9
6 changed files with 199 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from utils.generic_utils import sequence_mask


 class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))

-    def forward(self, memory, context, rnn_state, annotations,
-                attention_vec, mask=None, annotations_lengths=None):
+    def forward(self, memory, context, rnn_state, annots,
+                atten, annot_lens=None):
+        """
+        Shapes:
+            - memory: (batch, 1, dim) or (batch, dim)
+            - context: (batch, dim)
+            - rnn_state: (batch, out_dim)
+            - annots: (batch, max_time, annot_dim)
+            - atten: (batch, max_time)
+            - annot_lens: (batch,)
+        """
         # Concat input query and previous context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
         if self.align_model is 'b':
-            alignment = self.alignment_model(annotations, rnn_output)
+            alignment = self.alignment_model(annots, rnn_output)
         else:
-            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
-        # TODO: needs recheck.
-        if mask is not None:
-            mask = mask.view(query.size(0), -1)
-            alignment.data.masked_fill_(mask, self.score_mask_value)
+            alignment = self.alignment_model(annots, rnn_output, atten)
+        if annot_lens is not None:
+            mask = sequence_mask(annot_lens)
+            mask = mask.view(memory.size(0), -1)
+            alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         alignment = F.softmax(alignment, dim=-1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
-        context = torch.bmm(alignment.unsqueeze(1), annotations)
+        context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
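
The key change here is that padded annotation steps are pushed to -inf before the softmax, so they receive exactly zero attention weight. A minimal standalone sketch of the idea (the sequence_mask below is a stand-in for utils.generic_utils.sequence_mask, which this diff imports but does not show; the diff's "1 - mask" is the PyTorch 0.4 idiom for inverting a byte mask, written here as "~mask"):

import torch
import torch.nn.functional as F

def sequence_mask(lengths, max_len=None):
    # (batch, max_time) boolean mask, True where a step is inside the sequence
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

alignment = torch.randn(2, 5)        # raw attention scores, (batch, max_time)
annot_lens = torch.tensor([3, 5])    # valid annotation steps per sample
mask = sequence_mask(annot_lens)     # (2, 5)
alignment = alignment.masked_fill(~mask, -float("inf"))
weights = F.softmax(alignment, dim=-1)   # padded steps get weight exactly 0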

View File

@@ -1,26 +1,53 @@
 import torch
 from torch.nn import functional
 from torch import nn
+from utils.generic_utils import sequence_mask


-# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def _sequence_mask(sequence_length, max_len=None):
-    if max_len is None:
-        max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
-    return seq_range_expand < seq_length_expand
-
-
-class L2LossMasked(nn.Module):
+class L1LossMasked(nn.Module):
     def __init__(self):
-        super(L2LossMasked, self).__init__()
+        super(L1LossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                predicted values for each step.
+            target: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the target values
+                for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+        # input_flat: (batch * max_len, dim)
+        input = input.view(-1, input.shape[-1])
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.shape[-1])
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+        # mask: (batch, max_len, 1)
+        mask = sequence_mask(sequence_length=length,
+                             max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
+
+
+class MSELossMasked(nn.Module):
+    def __init__(self):
+        super(MSELossMasked, self).__init__()

     def forward(self, input, target, length):
         """
@@ -50,7 +77,7 @@ class L2LossMasked(nn.Module):
         losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
-        mask = _sequence_mask(sequence_length=length,
+        mask = sequence_mask(sequence_length=length,
                              max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
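
A quick usage sketch for the masked losses, with shapes taken from the docstring (the tensors and lengths here are illustrative, not from the repo):

import torch

criterion = L1LossMasked()
pred = torch.rand(4, 20, 80)             # (batch, max_len, dim) predicted frames
target = torch.rand(4, 20, 80)           # ground-truth frames, same shape
lengths = torch.tensor([20, 17, 12, 9])  # valid frames per batch item
loss = criterion(pred, target, lengths)  # padded frames contribute nothing

Note that size_average=False, reduce=False is the PyTorch 0.4-era API; current PyTorch spells the same thing reduction='none'.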

View File

@@ -213,7 +213,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)

-    def forward(self, inputs, memory=None):
+    def forward(self, inputs, memory=None, input_lens=None):
         """
         Decoder forward step.
@@ -225,6 +225,7 @@ class Decoder(nn.Module):
             memory (None): Decoder memory (autoregression). If None (at eval
                 time), the last decoder output is used as the next input.
+            input_lens (None): Time length of each input in batch.

         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -273,7 +274,8 @@ class Decoder(nn.Module):
             #                            attention_cum.unsqueeze(1)),
             #                           dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
+                processed_memory, current_context_vec, attention_rnn_hidden,
+                inputs, attention, input_lens)
             # attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
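
Since input_lens defaults to None, eval-time behavior is unchanged: with no lengths, the attention cell above simply skips the masking branch. Schematically (a sketch, not code from this diff):

# training: lengths are known, so attention over padded characters is masked
mel_outputs, alignments, stop_tokens = decoder(encoder_outputs, mel_specs, text_lens)

# eval: no teacher forcing and no lengths, so no mask is applied
mel_outputs, alignments, stop_tokens = decoder(encoder_outputs)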

View File

@@ -21,14 +21,14 @@ class Tacotron(nn.Module):
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, linear_dim)

-    def forward(self, characters, mel_specs=None):
+    def forward(self, characters, mel_specs=None, text_lens=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs)
+            encoder_outputs, mel_specs, text_lens)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
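
Putting the pieces together, the length tensor now threads from the top-level model down to the attention cell. A hypothetical end-to-end call (constructor arguments elided, shapes illustrative):

characters = torch.randint(0, 60, (2, 30))   # (batch, seq_len) character ids
mel_specs = torch.rand(2, 100, 80)           # (batch, frames, mel_dim) targets
text_lens = torch.tensor([30, 22])           # valid characters per sample

# Tacotron.forward -> Decoder.forward(input_lens=text_lens)
#   -> AttentionRNNCell.forward(annot_lens=input_lens) -> sequence_mask(annot_lens)
mel_out, linear_out, alignments, stop_tokens = model(characters, mel_specs, text_lens)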

View File

@@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
                                  count_parameters, check_update, get_commit_hash)
-from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
@@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # dispatch data to GPU
         if use_cuda:
             text_input = text_input.cuda()
+            text_lengths = text_lengths.cuda()
             mel_input = mel_input.cuda()
             mel_lengths = mel_lengths.cuda()
             linear_input = linear_input.cuda()
@@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # forward pass
         mel_output, linear_output, alignments, stop_tokens =\
-            model.forward(text_input, mel_input)
+            model.forward(text_input, mel_input, text_lengths)

         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
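
The loss computation that follows (outside this hunk) can then pair the masked criteria with the mel lengths; a sketch of that pattern, assuming criterion is one of the masked losses above (the exact weighting in the repo may differ):

# hypothetical continuation of the training step shown in this hunk
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = criterion(linear_output, linear_input, mel_lengths)
loss = mel_loss + linear_loss + stop_loss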

View File

@@ -12,7 +12,7 @@ class AudioProcessor(object):
     def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
                  frame_length_ms, preemphasis, ref_level_db, num_freq, power,
-                 griffin_lim_iters=None):
+                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db
@@ -22,6 +22,8 @@ class AudioProcessor(object):
         self.ref_level_db = ref_level_db
         self.num_freq = num_freq
         self.power = power
+        self.min_mel_freq = min_mel_freq
+        self.max_mel_freq = max_mel_freq
         self.griffin_lim_iters = griffin_lim_iters

     def save_wav(self, wav, path):
@@ -36,7 +38,8 @@ class AudioProcessor(object):
     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
+        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
+                                   fmin=self.min_mel_freq, fmax=self.max_mel_freq)

     def _normalize(self, S):
         return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@@ -63,7 +66,8 @@ class AudioProcessor(object):
         return signal.lfilter([1], [1, -self.preemphasis], x)

     def spectrogram(self, y):
-        D = self._stft(self.apply_preemphasis(y))
+        # D = self._stft(self.apply_preemphasis(y))
+        D = self._stft(y)
         S = self._amp_to_db(np.abs(D)) - self.ref_level_db
         return self._normalize(S)
@@ -72,7 +76,8 @@ class AudioProcessor(object):
         S = self._denormalize(spectrogram)
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         # Reconstruct phase
-        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        # return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)

     # def _griffin_lim(self, S):
     #     '''librosa implementation of Griffin-Lim
@@ -110,7 +115,7 @@ class AudioProcessor(object):
     def _istft(self, y):
         _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length)

     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
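
For reference, the effect of the new bounds on the mel filterbank; a standalone sketch with illustrative values (num_freq=1025, 22050 Hz audio, and 125-7600 Hz bounds are assumptions, not values from this diff; recent librosa versions require keyword arguments here):

import librosa

n_fft = (1025 - 1) * 2   # num_freq = 1025  ->  n_fft = 2048
mel_basis = librosa.filters.mel(sr=22050, n_fft=n_fft, n_mels=80,
                                fmin=125, fmax=7600)
print(mel_basis.shape)   # (80, 1025): n_mels x (n_fft // 2 + 1)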