Attn masking

Eren G 2018-07-13 14:50:55 +02:00
parent 9f52833151
commit dac8fdffa9
6 changed files with 199 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
from utils.generic_utils import sequence_mask
class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
def forward(self, memory, context, rnn_state, annotations,
attention_vec, mask=None, annotations_lengths=None):
def forward(self, memory, context, rnn_state, annots,
atten, annot_lens=None):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
- context: (batch, dim)
- rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim)
- atten: (batch, max_time)
- annot_lens: (batch,)
"""
# Concatenate the input query and the previous context vector
rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
if self.align_model == 'b':
alignment = self.alignment_model(annotations, rnn_output)
alignment = self.alignment_model(annots, rnn_output)
else:
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
# TODO: needs recheck.
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
alignment = self.alignment_model(annots, rnn_output, atten)
if annot_lens is not None:
mask = sequence_mask(annot_lens)
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize attention weights
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context = torch.bmm(alignment.unsqueeze(1), annotations)
context = torch.bmm(alignment.unsqueeze(1), annots)
context = context.squeeze(1)
return rnn_output, context, alignment
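The hunk above is the core of this commit: alignment scores at padded annotation steps are filled with -inf before the softmax, so padding receives exactly zero attention weight. A minimal standalone sketch of the idea (sequence_mask is inlined here under the assumption that it matches the removed _sequence_mask helper in the losses diff below; ~mask is the modern-PyTorch equivalent of the 1 - mask byte-tensor idiom in the diff):

import torch
from torch.nn import functional as F

def sequence_mask(lengths, max_len=None):
    # True for valid time steps, False for padding
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

scores = torch.randn(2, 4)                         # toy alignment scores (batch, max_time)
lens = torch.tensor([4, 2])                        # true annotation lengths
mask = sequence_mask(lens)                         # (2, 4) boolean
scores = scores.masked_fill(~mask, -float("inf"))
weights = F.softmax(scores, dim=-1)                # rows sum to 1; padded steps get weight 0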

View File

@@ -1,26 +1,53 @@
import torch
from torch.nn import functional
from torch import nn
from utils.generic_utils import sequence_mask
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def _sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand
class L2LossMasked(nn.Module):
class L1LossMasked(nn.Module):
def __init__(self):
super(L2LossMasked, self).__init__()
super(L1LossMasked, self).__init__()
def forward(self, input, target, length):
"""
Args:
input: A Variable containing a FloatTensor of size
(batch, max_len, dim) which contains the predicted
values for each step.
target: A Variable containing a FloatTensor of size
(batch, max_len, dim) which contains the target values
for each corresponding step.
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each sequence in the batch.
Returns:
loss: An average loss value masked by the length.
"""
input = input.contiguous()
target = target.contiguous()
# input flattened: (batch * max_len, dim)
input = input.view(-1, input.shape[-1])
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss
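A quick numeric check of the masking arithmetic in L1LossMasked (standalone sketch; the values are illustrative):

import torch

batch, max_len, dim = 2, 5, 3
pred = torch.zeros(batch, max_len, dim)
target = torch.ones(batch, max_len, dim)
lengths = torch.tensor([5, 2])

# the same (batch, max_len, 1) mask the loss builds: True on valid frames
mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(2)
losses = torch.abs(pred - target) * mask.float()
loss = losses.sum() / (lengths.float().sum() * dim)
print(loss.item())  # 1.0 -- the three padded frames of the second item are ignored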
class MSELossMasked(nn.Module):
def __init__(self):
super(MSELossMasked, self).__init__()
def forward(self, input, target, length):
"""
@@ -48,9 +75,9 @@ class L2LossMasked(nn.Module):
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length,
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))

View File

@@ -213,7 +213,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None):
def forward(self, inputs, memory=None, input_lens=None):
"""
Decoder forward step.
@@ -225,6 +225,7 @@
memory (None): Decoder memory (autoregression). If None (at eval time),
the last decoder output is used as the next decoder input.
input_lens (None): Time length of each input in the batch.
Shapes:
- inputs: batch x time x encoder_out_dim
@@ -273,7 +274,8 @@
# attention_cum.unsqueeze(1)),
# dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention, input_lens)
# attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
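For reference, the positional arguments of this call line up with the new AttentionRNNCell.forward signature from the first hunk as follows:

# memory     <- processed_memory       (projected previous decoder frame)
# context    <- current_context_vec
# rnn_state  <- attention_rnn_hidden
# annots     <- inputs                 (encoder outputs)
# atten      <- attention              (previous alignment)
# annot_lens <- input_lens             (true encoder lengths, used for masking)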

View File

@@ -21,14 +21,14 @@ class Tacotron(nn.Module):
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
def forward(self, characters, mel_specs=None):
def forward(self, characters, mel_specs=None, text_lens=None):
B = characters.size(0)
inputs = self.embedding(characters)
# batch x time x dim
encoder_outputs = self.encoder(inputs)
# batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs)
encoder_outputs, mel_specs, text_lens)
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
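The final view() unpacks the r frames the decoder emits per step. A toy illustration of the reshape (the mel_dim and r values are only examples):

import torch

B, T, mel_dim, r = 2, 10, 80, 5
packed = torch.randn(B, T, mel_dim * r)   # decoder output: r frames packed per step
unpacked = packed.view(B, -1, mel_dim)    # (2, 50, 80): one mel frame per time step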

View File

@@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay,
count_parameters, check_update, get_commit_hash)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
@@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda()
text_lengths = text_lengths.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
@@ -103,7 +103,7 @@
# forward pass
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input)
model.forward(text_input, mel_input, text_lengths)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets)
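The masked criteria defined above take the target lengths as a third argument. A hedged sketch of how the remaining loss terms would be computed with them (criterion is assumed to be L1LossMasked or MSELossMasked; this is not the verbatim training code):

mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = criterion(linear_output, linear_input, mel_lengths)
loss = mel_loss + linear_loss + stop_loss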

View File

@@ -1,122 +1,127 @@
import os
import librosa
import pickle
import copy
import numpy as np
from scipy import signal
_mel_basis = None
class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.preemphasis = preemphasis
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.griffin_lim_iters = griffin_lim_iters
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def _stft_parameters(self, ):
n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
return n_fft, hop_length, win_length
def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x))
def _db_to_amp(self, x):
return np.power(10.0, x * 0.05)
def apply_preemphasis(self, x):
return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x):
return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S)
def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
# def _griffin_lim(self, S):
# '''librosa implementation of Griffin-Lim
# Based on https://github.com/librosa/librosa/issues/434
# '''
# angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
# S_complex = np.abs(S).astype(np.complex)
# y = self._istft(S_complex * angles)
# for i in range(self.griffin_lim_iters):
# angles = np.exp(1j * np.angle(self._stft(y)))
# y = self._istft(S_complex * angles)
# return y
def _griffin_lim(self, S):
'''Applies the raw Griffin-Lim algorithm: alternates ISTFT and STFT,
keeping the target magnitude and re-estimating only the phase.
'''
S_best = copy.deepcopy(S)
for i in range(self.griffin_lim_iters):
S_t = self._istft(S_best)
est = self._stft(S_t)
phase = est / np.maximum(1e-8, np.abs(est))
S_best = S * phase
S_t = self._istft(S_best)
y = np.real(S_t)
return y
def melspectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def _stft(self, y):
n_fft, hop_length, win_length = self._stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = self._db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)
import os
import librosa
import pickle
import copy
import numpy as np
from scipy import signal
_mel_basis = None
class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
min_mel_freq, max_mel_freq, griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.preemphasis = preemphasis
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.min_mel_freq = min_mel_freq
self.max_mel_freq = max_mel_freq
self.griffin_lim_iters = griffin_lim_iters
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
fmin=self.min_mel_freq, fmax=self.max_mel_freq)
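A minimal shape check for the new mel basis (the fmin/fmax values are illustrative; the positional call matches the librosa 0.6-era signature used above, while newer librosa requires keyword arguments):

import librosa

basis = librosa.filters.mel(22050, 2048, n_mels=80, fmin=125, fmax=7600)
print(basis.shape)  # (80, 1025) == (num_mels, num_freq)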
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def _stft_parameters(self, ):
n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
return n_fft, hop_length, win_length
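A worked example of _stft_parameters (the sample rate, frame times, and num_freq are illustrative, not necessarily this repo's config):

num_freq, sample_rate = 1025, 22050
n_fft = (num_freq - 1) * 2                     # 2048
hop_length = int(12.5 / 1000.0 * sample_rate)  # 275  (12.5 ms frame shift)
win_length = int(50.0 / 1000.0 * sample_rate)  # 1102 (50 ms frame length)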
def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x))
def _db_to_amp(self, x):
return np.power(10.0, x * 0.05)
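_amp_to_db and _db_to_amp invert each other above the 1e-5 amplitude floor; a quick check:

import numpy as np

x = np.array([1e-3, 0.5, 1.0])
db = 20 * np.log10(np.maximum(1e-5, x))
assert np.allclose(np.power(10.0, db * 0.05), x)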
def apply_preemphasis(self, x):
return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x):
return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(self, y):
# D = self._stft(self.apply_preemphasis(y))
D = self._stft(y)
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S)
def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
# return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
# def _griffin_lim(self, S):
# '''librosa implementation of Griffin-Lim
# Based on https://github.com/librosa/librosa/issues/434
# '''
# angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
# S_complex = np.abs(S).astype(np.complex)
# y = self._istft(S_complex * angles)
# for i in range(self.griffin_lim_iters):
# angles = np.exp(1j * np.angle(self._stft(y)))
# y = self._istft(S_complex * angles)
# return y
def _griffin_lim(self, S):
'''Applies the raw Griffin-Lim algorithm: alternates ISTFT and STFT,
keeping the target magnitude and re-estimating only the phase.
'''
S_best = copy.deepcopy(S)
for i in range(self.griffin_lim_iters):
S_t = self._istft(S_best)
est = self._stft(S_t)
phase = est / np.maximum(1e-8, np.abs(est))
S_best = S * phase
S_t = self._istft(S_best)
y = np.real(S_t)
return y
def melspectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def _stft(self, y):
n_fft, hop_length, win_length = self._stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = self._db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)
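An end-to-end usage sketch of the updated class, from analysis to Griffin-Lim resynthesis (all constructor values here are illustrative, not the repo's config):

import numpy as np

ap = AudioProcessor(sample_rate=22050, num_mels=80, min_level_db=-100,
                    frame_shift_ms=12.5, frame_length_ms=50, preemphasis=0.97,
                    ref_level_db=20, num_freq=1025, power=1.5,
                    min_mel_freq=125, max_mel_freq=7600, griffin_lim_iters=60)
t = np.arange(22050) / 22050.0
wav = np.sin(2 * np.pi * 440 * t).astype(np.float32)
S = ap.spectrogram(wav)          # normalized dB magnitude, (1025, frames)
wav_hat = ap.inv_spectrogram(S)  # Griffin-Lim phase reconstruction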