Attn masking

Eren G 2018-07-13 14:50:55 +02:00
parent 9f52833151
commit dac8fdffa9
6 changed files with 199 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from utils.generic_utils import sequence_mask


 class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))

-    def forward(self, memory, context, rnn_state, annotations,
-                attention_vec, mask=None, annotations_lengths=None):
+    def forward(self, memory, context, rnn_state, annots,
+                atten, annot_lens=None):
+        """
+        Shapes:
+            - memory: (batch, 1, dim) or (batch, dim)
+            - context: (batch, dim)
+            - rnn_state: (batch, out_dim)
+            - annots: (batch, max_time, annot_dim)
+            - atten: (batch, max_time)
+            - annot_lens: (batch,)
+        """
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
         if self.align_model is 'b':
-            alignment = self.alignment_model(annotations, rnn_output)
+            alignment = self.alignment_model(annots, rnn_output)
         else:
-            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
-        # TODO: needs recheck.
-        if mask is not None:
-            mask = mask.view(query.size(0), -1)
-            alignment.data.masked_fill_(mask, self.score_mask_value)
+            alignment = self.alignment_model(annots, rnn_output, atten)
+        if annot_lens is not None:
+            mask = sequence_mask(annot_lens)
+            mask = mask.view(memory.size(0), -1)
+            alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         alignment = F.softmax(alignment, dim=-1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
-        context = torch.bmm(alignment.unsqueeze(1), annotations)
+        context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
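
Note on the new masking path: padded encoder steps now get their energies set to -inf before the softmax, so they receive exactly zero attention weight, replacing the old `mask`/`score_mask_value` arguments (the removed branch also referenced an undefined `query`). Two caveats: `self.align_model is 'b'` compares strings with `is`, which only works by interning accident and should be `==`; and `1 - mask` assumes the uint8 masks of PyTorch 0.4, where on modern bool masks `~mask` is the equivalent. A minimal sketch of the idea, with `sequence_mask` assumed to behave like the `_sequence_mask` helper this commit removes from the losses file:

```python
import torch
import torch.nn.functional as F

def sequence_mask(sequence_length, max_len=None):
    # (batch, max_len) boolean mask: True at valid steps, False at padding.
    if max_len is None:
        max_len = sequence_length.max().item()
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

# Toy alignment energies: batch of 2, padded to max_time = 4.
alignment = torch.randn(2, 4)
annot_lens = torch.tensor([4, 2])

mask = sequence_mask(annot_lens)  # [[T, T, T, T], [T, T, F, F]]
alignment = alignment.masked_fill(~mask, -float("inf"))
weights = F.softmax(alignment, dim=-1)
# Rows sum to 1, and weights[1, 2:] is exactly zero.
```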

View File

@@ -1,26 +1,53 @@
 import torch
 from torch.nn import functional
 from torch import nn
+from utils.generic_utils import sequence_mask


-# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def _sequence_mask(sequence_length, max_len=None):
-    if max_len is None:
-        max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
-    return seq_range_expand < seq_length_expand


-class L2LossMasked(nn.Module):
+class L1LossMasked(nn.Module):

     def __init__(self):
-        super(L2LossMasked, self).__init__()
+        super(L1LossMasked, self).__init__()

+    def forward(self, input, target, length):
+        """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len, dim) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+        # logits_flat: (batch * max_len, dim)
+        input = input.view(-1, input.shape[-1])
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.shape[-1])
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+        # mask: (batch, max_len, 1)
+        mask = sequence_mask(sequence_length=length,
+                             max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
+
+
+class MSELossMasked(nn.Module):
+
+    def __init__(self):
+        super(MSELossMasked, self).__init__()

     def forward(self, input, target, length):
         """
@@ -48,9 +75,9 @@ class L2LossMasked(nn.Module):
                                         reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
-        mask = _sequence_mask(sequence_length=length,
+        mask = sequence_mask(sequence_length=length,
                              max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
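
The masked losses normalize by the number of valid frames times the feature dimension, so padding never dilutes the average. `size_average=False, reduce=False` is the older PyTorch spelling of `reduction="none"`. A self-contained check of the masking property, written with the modern argument (the inlined `sequence_mask` is an assumption standing in for the one imported from `utils.generic_utils`):

```python
import torch
from torch.nn import functional

def sequence_mask(length, max_len):
    # True at valid steps, False at padding; shape (batch, max_len).
    return torch.arange(max_len).unsqueeze(0) < length.unsqueeze(1)

batch, max_len, dim = 2, 5, 80
pred = torch.rand(batch, max_len, dim)
target = torch.rand(batch, max_len, dim)
lengths = torch.tensor([5, 3])  # second item carries 2 padded frames

# Same computation as L1LossMasked.forward above.
losses = functional.l1_loss(pred, target, reduction="none")
mask = sequence_mask(lengths, max_len).unsqueeze(2).float()
loss = (losses * mask).sum() / (lengths.float().sum() * dim)

# Corrupting the padded region must not change the loss.
pred[1, 3:] = 1e3
losses2 = functional.l1_loss(pred, target, reduction="none")
loss2 = (losses2 * mask).sum() / (lengths.float().sum() * dim)
assert torch.allclose(loss, loss2)
```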

View File

@@ -213,7 +213,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)

-    def forward(self, inputs, memory=None):
+    def forward(self, inputs, memory=None, input_lens=None):
         """
         Decoder forward step.
@@ -225,6 +225,7 @@ class Decoder(nn.Module):
             memory (None): Decoder memory (autoregression. If None (at eval-time),
                 decoder outputs are used as decoder inputs. If None, it uses the last
                 output as the input.
+            input_lens (None): Time length of each input in batch.

         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -273,7 +274,8 @@ class Decoder(nn.Module):
             #                            attention_cum.unsqueeze(1)),
             #                           dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
+                processed_memory, current_context_vec, attention_rnn_hidden,
+                inputs, attention, input_lens)
             # attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(

View File

@@ -21,14 +21,14 @@ class Tacotron(nn.Module):
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, linear_dim)

-    def forward(self, characters, mel_specs=None):
+    def forward(self, characters, mel_specs=None, text_lens=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs)
+            encoder_outputs, mel_specs, text_lens)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
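
End to end, the lengths now flow `Tacotron.forward` -> `Decoder.forward` -> `AttentionRNNCell.forward`, so the attention softmax never assigns weight to padded character positions. A hypothetical smoke test of the new signature; the constructor arguments (embedding size, linear dim, mel dim, r) are assumptions, not part of this diff:

```python
import torch
from models.tacotron import Tacotron

model = Tacotron(256, 1025, 80, 5)  # assumed constructor, for illustration

batch, max_text_len = 2, 37
characters = torch.randint(1, 60, (batch, max_text_len))
text_lens = torch.tensor([37, 21])      # second item is padded
mel_specs = torch.rand(batch, 150, 80)  # teacher-forcing targets, T % r == 0

mel_out, linear_out, alignments, stop_tokens = model.forward(
    characters, mel_specs, text_lens)
# alignments: (batch, decoder_steps, max_text_len); for the second item,
# columns past position 21 carry zero attention weight.
```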

View File

@@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
                                  count_parameters, check_update, get_commit_hash)
-from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
@@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # dispatch data to GPU
         if use_cuda:
             text_input = text_input.cuda()
+            text_lengths = text_lengths.cuda()
             mel_input = mel_input.cuda()
             mel_lengths = mel_lengths.cuda()
             linear_input = linear_input.cuda()
@@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # forward pass
         mel_output, linear_output, alignments, stop_tokens =\
-            model.forward(text_input, mel_input)
+            model.forward(text_input, mel_input, text_lengths)
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
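
`text_lengths` already arrives from the data loader's collate function; the change just moves it to the GPU with the rest of the batch and threads it into the forward pass. If lengths ever had to be recovered from an already padded batch, a sketch (pad index 0 is an assumption here, purely illustrative):

```python
import torch

# Recover per-item lengths from a zero-padded batch of character ids.
text_input = torch.tensor([[12, 7, 33, 0, 0],
                           [5, 19, 2, 41, 8]])
text_lengths = (text_input != 0).sum(dim=1)  # tensor([3, 5])

if torch.cuda.is_available():
    text_input = text_input.cuda()
    text_lengths = text_lengths.cuda()
```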

View File

@@ -1,122 +1,127 @@
 import os
 import librosa
 import pickle
 import copy
 import numpy as np
 from scipy import signal

 _mel_basis = None


 class AudioProcessor(object):

     def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
                  frame_length_ms, preemphasis, ref_level_db, num_freq, power,
-                 griffin_lim_iters=None):
+                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db
         self.frame_shift_ms = frame_shift_ms
         self.frame_length_ms = frame_length_ms
         self.preemphasis = preemphasis
         self.ref_level_db = ref_level_db
         self.num_freq = num_freq
         self.power = power
+        self.min_mel_freq = min_mel_freq
+        self.max_mel_freq = max_mel_freq
         self.griffin_lim_iters = griffin_lim_iters

     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
         librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)

     def _linear_to_mel(self, spectrogram):
         global _mel_basis
         if _mel_basis is None:
             _mel_basis = self._build_mel_basis()
         return np.dot(_mel_basis, spectrogram)

     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
+        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
+                                   fmin=self.min_mel_freq, fmax=self.max_mel_freq)

     def _normalize(self, S):
         return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)

     def _denormalize(self, S):
         return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db

     def _stft_parameters(self, ):
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
         return n_fft, hop_length, win_length

     def _amp_to_db(self, x):
         return 20 * np.log10(np.maximum(1e-5, x))

     def _db_to_amp(self, x):
         return np.power(10.0, x * 0.05)

     def apply_preemphasis(self, x):
         return signal.lfilter([1, -self.preemphasis], [1], x)

     def apply_inv_preemphasis(self, x):
         return signal.lfilter([1], [1, -self.preemphasis], x)

     def spectrogram(self, y):
-        D = self._stft(self.apply_preemphasis(y))
+        # D = self._stft(self.apply_preemphasis(y))
+        D = self._stft(y)
         S = self._amp_to_db(np.abs(D)) - self.ref_level_db
         return self._normalize(S)

     def inv_spectrogram(self, spectrogram):
         '''Converts spectrogram to waveform using librosa'''
         S = self._denormalize(spectrogram)
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         # Reconstruct phase
-        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        # return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)

     # def _griffin_lim(self, S):
     #     '''librosa implementation of Griffin-Lim
     #     Based on https://github.com/librosa/librosa/issues/434
     #     '''
     #     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
     #     S_complex = np.abs(S).astype(np.complex)
     #     y = self._istft(S_complex * angles)
     #     for i in range(self.griffin_lim_iters):
     #         angles = np.exp(1j * np.angle(self._stft(y)))
     #         y = self._istft(S_complex * angles)
     #     return y

     def _griffin_lim(self, S):
         '''Applies Griffin-Lim's raw.
         '''
         S_best = copy.deepcopy(S)
         for i in range(self.griffin_lim_iters):
             S_t = self._istft(S_best)
             est = self._stft(S_t)
             phase = est / np.maximum(1e-8, np.abs(est))
             S_best = S * phase
         S_t = self._istft(S_best)
         y = np.real(S_t)
         return y

     def melspectrogram(self, y):
         D = self._stft(self.apply_preemphasis(y))
         S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
         return self._normalize(S)

     def _stft(self, y):
         n_fft, hop_length, win_length = self._stft_parameters()
         return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

     def _istft(self, y):
         _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length)

     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
         hop_length = int(window_length / 4)
         threshold = self._db_to_amp(threshold_db)
         for x in range(hop_length, len(wav) - window_length, hop_length):
             if np.max(wav[x:x + window_length]) < threshold:
                 return x + hop_length
         return len(wav)
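
Two notes on this rewrite: pre-emphasis is now dropped symmetrically (`spectrogram` no longer applies it and `inv_spectrogram` no longer inverts it, though `melspectrogram` still uses it), and the mel filterbank is band-limited by the new `min_mel_freq`/`max_mel_freq` arguments. Also worth knowing when running this today: `librosa.output.write_wav` and `np.float` are gone from modern librosa/NumPy. A sketch of what the frequency limits do to the basis; the 125/7600 Hz values are illustrative, and the keyword call form works across librosa versions where the positional call above was the 2018-era style:

```python
import librosa
import numpy as np

sample_rate, num_freq, num_mels = 22050, 1025, 80
n_fft = (num_freq - 1) * 2  # 2048, as in _build_mel_basis

basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels,
                            fmin=125, fmax=7600)  # limits are illustrative
print(basis.shape)  # (80, 1025): one row of triangular weights per mel band

# Rows are zero outside [fmin, fmax], so the bands ignore low-frequency
# rumble and the top octave.
freqs = np.linspace(0, sample_rate / 2, num_freq)
active = freqs[basis.sum(axis=0) > 0]
print(active.min(), active.max())  # roughly 125 .. 7600 Hz
```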