From dac8fdffa9204d3ef85331c8f6e73998ac4f1ff8 Mon Sep 17 00:00:00 2001
From: Eren G
Date: Fri, 13 Jul 2018 14:50:55 +0200
Subject: [PATCH] Attn masking

---
 layers/attention.py |  28 +++--
 layers/losses.py    |  63 +++++++----
 layers/tacotron.py  |   6 +-
 models/tacotron.py  |   4 +-
 train.py            |   4 +-
 utils/audio.py      | 249 ++++++++++++++++++++++----------------------
 6 files changed, 199 insertions(+), 155 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index af1d9e55..19d8924e 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from utils.generic_utils import sequence_mask
 
 
 class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
 
-    def forward(self, memory, context, rnn_state, annotations,
-                attention_vec, mask=None, annotations_lengths=None):
+    def forward(self, memory, context, rnn_state, annots,
+                atten, annot_lens=None):
+        """
+        Shapes:
+            - memory: (batch, 1, dim) or (batch, dim)
+            - context: (batch, dim)
+            - rnn_state: (batch, out_dim)
+            - annots: (batch, max_time, annot_dim)
+            - atten: (batch, max_time)
+            - annot_lens: (batch,)
+        """
         # Concat input query and previous context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
         if self.align_model is 'b':
-            alignment = self.alignment_model(annotations, rnn_output)
+            alignment = self.alignment_model(annots, rnn_output)
         else:
-            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
-        # TODO: needs recheck.
-        if mask is not None:
-            mask = mask.view(query.size(0), -1)
-            alignment.data.masked_fill_(mask, self.score_mask_value)
+            alignment = self.alignment_model(annots, rnn_output, atten)
+        if annot_lens is not None:
+            mask = sequence_mask(annot_lens)
+            mask = mask.view(memory.size(0), -1)
+            alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         alignment = F.softmax(alignment, dim=-1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
diff --git a/layers/losses.py b/layers/losses.py
index a9d393cb..d7b21c38 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -1,26 +1,53 @@
 import torch
 from torch.nn import functional
 from torch import nn
+from utils.generic_utils import sequence_mask
 
 
-# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def _sequence_mask(sequence_length, max_len=None):
-    if max_len is None:
-        max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
-    return seq_range_expand < seq_length_expand
-
-
-class L2LossMasked(nn.Module):
+class L1LossMasked(nn.Module):
 
     def __init__(self):
-        super(L2LossMasked, self).__init__()
+        super(L1LossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                predicted values for each step.
+            target: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the ground-truth
+                values for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each sequence in the batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+
+        # input_flat: (batch * max_len, dim)
+        input = input.view(-1, input.shape[-1])
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.shape[-1])
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+
+        # mask: (batch, max_len, 1)
+        mask = sequence_mask(sequence_length=length,
+                             max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
+
+
+class MSELossMasked(nn.Module):
+
+    def __init__(self):
+        super(MSELossMasked, self).__init__()
 
     def forward(self, input, target, length):
         """
@@ -48,9 +75,9 @@ class L2LossMasked(nn.Module):
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
-
+
         # mask: (batch, max_len, 1)
-        mask = _sequence_mask(sequence_length=length,
+        mask = sequence_mask(sequence_length=length,
                              max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 78c75011..7f856b33 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -213,7 +213,7 @@ class Decoder(nn.Module):
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)
 
-    def forward(self, inputs, memory=None):
+    def forward(self, inputs, memory=None, input_lens=None):
         """
         Decoder forward step.
 
@@ -225,6 +225,7 @@ class Decoder(nn.Module):
             memory (None): Decoder memory (used for autoregression). If None
                 (at eval time), the decoder uses its own last output as the
                 input at each step.
+            input_lens (None): Time length of each input in the batch, used to
+                mask attention over padding.
 
Shapes: - inputs: batch x time x encoder_out_dim @@ -273,7 +274,8 @@ class Decoder(nn.Module): # attention_cum.unsqueeze(1)), # dim=1) attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( - processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention) + processed_memory, current_context_vec, attention_rnn_hidden, + inputs, attention, input_lens) # attention_cum += attention # Concat RNN output and attention context vector decoder_input = self.project_to_decoder_in( diff --git a/models/tacotron.py b/models/tacotron.py index 1b0923a4..d07bdd6f 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -21,14 +21,14 @@ class Tacotron(nn.Module): self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim]) self.last_linear = nn.Linear(mel_dim * 2, linear_dim) - def forward(self, characters, mel_specs=None): + def forward(self, characters, mel_specs=None, text_lens=None): B = characters.size(0) inputs = self.embedding(characters) # batch x time x dim encoder_outputs = self.encoder(inputs) # batch x time x dim*r mel_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs) + encoder_outputs, mel_specs, text_lens) # Reshape # batch x time x dim mel_outputs = mel_outputs.view(B, -1, self.mel_dim) diff --git a/train.py b/train.py index ea9c6041..113346bb 100644 --- a/train.py +++ b/train.py @@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder, create_experiment_folder, save_checkpoint, save_best_model, load_config, lr_decay, count_parameters, check_update, get_commit_hash) -from utils.model import get_param_size from utils.visual import plot_alignment, plot_spectrogram from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron @@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, # dispatch data to GPU if use_cuda: text_input = text_input.cuda() + text_lengths = text_lengths.cuda() mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() linear_input = linear_input.cuda() @@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, # forward pass mel_output, linear_output, alignments, stop_tokens =\ - model.forward(text_input, mel_input) + model.forward(text_input, mel_input, text_lengths) # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) diff --git a/utils/audio.py b/utils/audio.py index 9b63d99e..6595b419 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -1,122 +1,127 @@ -import os -import librosa -import pickle -import copy -import numpy as np -from scipy import signal - -_mel_basis = None - - -class AudioProcessor(object): - - def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, preemphasis, ref_level_db, num_freq, power, - griffin_lim_iters=None): - self.sample_rate = sample_rate - self.num_mels = num_mels - self.min_level_db = min_level_db - self.frame_shift_ms = frame_shift_ms - self.frame_length_ms = frame_length_ms - self.preemphasis = preemphasis - self.ref_level_db = ref_level_db - self.num_freq = num_freq - self.power = power - self.griffin_lim_iters = griffin_lim_iters - - def save_wav(self, wav, path): - wav *= 32767 / max(0.01, np.max(np.abs(wav))) - librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True) - - def _linear_to_mel(self, spectrogram): - global _mel_basis - if _mel_basis is None: - _mel_basis = self._build_mel_basis() - return np.dot(_mel_basis, spectrogram) - - def 
_build_mel_basis(self, ): - n_fft = (self.num_freq - 1) * 2 - return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels) - - def _normalize(self, S): - return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) - - def _denormalize(self, S): - return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db - - def _stft_parameters(self, ): - n_fft = (self.num_freq - 1) * 2 - hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) - win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate) - return n_fft, hop_length, win_length - - def _amp_to_db(self, x): - return 20 * np.log10(np.maximum(1e-5, x)) - - def _db_to_amp(self, x): - return np.power(10.0, x * 0.05) - - def apply_preemphasis(self, x): - return signal.lfilter([1, -self.preemphasis], [1], x) - - def apply_inv_preemphasis(self, x): - return signal.lfilter([1], [1, -self.preemphasis], x) - - def spectrogram(self, y): - D = self._stft(self.apply_preemphasis(y)) - S = self._amp_to_db(np.abs(D)) - self.ref_level_db - return self._normalize(S) - - def inv_spectrogram(self, spectrogram): - '''Converts spectrogram to waveform using librosa''' - S = self._denormalize(spectrogram) - S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear - # Reconstruct phase - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) - -# def _griffin_lim(self, S): -# '''librosa implementation of Griffin-Lim -# Based on https://github.com/librosa/librosa/issues/434 -# ''' -# angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) -# S_complex = np.abs(S).astype(np.complex) -# y = self._istft(S_complex * angles) -# for i in range(self.griffin_lim_iters): -# angles = np.exp(1j * np.angle(self._stft(y))) -# y = self._istft(S_complex * angles) -# return y - - def _griffin_lim(self, S): - '''Applies Griffin-Lim's raw. 
- ''' - S_best = copy.deepcopy(S) - for i in range(self.griffin_lim_iters): - S_t = self._istft(S_best) - est = self._stft(S_t) - phase = est / np.maximum(1e-8, np.abs(est)) - S_best = S * phase - S_t = self._istft(S_best) - y = np.real(S_t) - return y - - def melspectrogram(self, y): - D = self._stft(self.apply_preemphasis(y)) - S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db - return self._normalize(S) - - def _stft(self, y): - n_fft, hop_length, win_length = self._stft_parameters() - return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) - - def _istft(self, y): - _, hop_length, win_length = self._stft_parameters() - return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann') - - def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8): - window_length = int(self.sample_rate * min_silence_sec) - hop_length = int(window_length / 4) - threshold = self._db_to_amp(threshold_db) - for x in range(hop_length, len(wav) - window_length, hop_length): - if np.max(wav[x:x + window_length]) < threshold: - return x + hop_length - return len(wav) +import os +import librosa +import pickle +import copy +import numpy as np +from scipy import signal + +_mel_basis = None + + +class AudioProcessor(object): + + def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms, + frame_length_ms, preemphasis, ref_level_db, num_freq, power, + min_mel_freq, max_mel_freq, griffin_lim_iters=None): + self.sample_rate = sample_rate + self.num_mels = num_mels + self.min_level_db = min_level_db + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.preemphasis = preemphasis + self.ref_level_db = ref_level_db + self.num_freq = num_freq + self.power = power + self.min_mel_freq = min_mel_freq + self.max_mel_freq = max_mel_freq + self.griffin_lim_iters = griffin_lim_iters + + def save_wav(self, wav, path): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True) + + def _linear_to_mel(self, spectrogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = self._build_mel_basis() + return np.dot(_mel_basis, spectrogram) + + def _build_mel_basis(self, ): + n_fft = (self.num_freq - 1) * 2 + return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels, + fmin=self.min_mel_freq, fmax=self.max_mel_freq) + + def _normalize(self, S): + return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) + + def _denormalize(self, S): + return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db + + def _stft_parameters(self, ): + n_fft = (self.num_freq - 1) * 2 + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate) + return n_fft, hop_length, win_length + + def _amp_to_db(self, x): + return 20 * np.log10(np.maximum(1e-5, x)) + + def _db_to_amp(self, x): + return np.power(10.0, x * 0.05) + + def apply_preemphasis(self, x): + return signal.lfilter([1, -self.preemphasis], [1], x) + + def apply_inv_preemphasis(self, x): + return signal.lfilter([1], [1, -self.preemphasis], x) + + def spectrogram(self, y): + # D = self._stft(self.apply_preemphasis(y)) + D = self._stft(y) + S = self._amp_to_db(np.abs(D)) - self.ref_level_db + return self._normalize(S) + + def inv_spectrogram(self, spectrogram): + '''Converts spectrogram to waveform using librosa''' + S = self._denormalize(spectrogram) + S = self._db_to_amp(S + self.ref_level_db) # 
Convert back to linear
+        # Reconstruct phase
+        # return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)
+
+# def _griffin_lim(self, S):
+#     '''librosa implementation of Griffin-Lim
+#     Based on https://github.com/librosa/librosa/issues/434
+#     '''
+#     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+#     S_complex = np.abs(S).astype(np.complex)
+#     y = self._istft(S_complex * angles)
+#     for i in range(self.griffin_lim_iters):
+#         angles = np.exp(1j * np.angle(self._stft(y)))
+#         y = self._istft(S_complex * angles)
+#     return y
+
+    def _griffin_lim(self, S):
+        '''Applies the raw Griffin-Lim phase reconstruction.
+        '''
+        S_best = copy.deepcopy(S)
+        for i in range(self.griffin_lim_iters):
+            S_t = self._istft(S_best)
+            est = self._stft(S_t)
+            phase = est / np.maximum(1e-8, np.abs(est))
+            S_best = S * phase
+        S_t = self._istft(S_best)
+        y = np.real(S_t)
+        return y
+
+    def melspectrogram(self, y):
+        D = self._stft(self.apply_preemphasis(y))
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        return self._normalize(S)
+
+    def _stft(self, y):
+        n_fft, hop_length, win_length = self._stft_parameters()
+        return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+
+    def _istft(self, y):
+        _, hop_length, win_length = self._stft_parameters()
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length)
+
+    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
+        window_length = int(self.sample_rate * min_silence_sec)
+        hop_length = int(window_length / 4)
+        threshold = self._db_to_amp(threshold_db)
+        for x in range(hop_length, len(wav) - window_length, hop_length):
+            if np.max(wav[x:x + window_length]) < threshold:
+                return x + hop_length
+        return len(wav)
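
Note: both layers/attention.py and layers/losses.py now import sequence_mask from
utils.generic_utils, but that module is not part of this diff. A minimal sketch of
what the helper is assumed to look like, mirroring the _sequence_mask function that
this patch removes from layers/losses.py:

    import torch


    def sequence_mask(sequence_length, max_len=None):
        # Build a (batch, max_len) byte mask that is 1 for valid (non-padded)
        # positions and 0 for padding, given per-item lengths.
        if max_len is None:
            max_len = sequence_length.data.max()
        batch_size = sequence_length.size(0)
        seq_range = torch.arange(0, max_len).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        if sequence_length.is_cuda:
            seq_range_expand = seq_range_expand.cuda()
        seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
        return seq_range_expand < seq_length_expand

With such a mask, the new branch in AttentionRNNCell.forward fills the alignment
scores of padded encoder steps with -inf before the softmax, so padding receives
zero attention weight.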
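
The masked loss semantics can be sanity-checked on a padded batch. This is a rough
usage sketch, assuming the sequence_mask helper above and the PyTorch 0.4-era
functional API (size_average/reduce flags) that the patch relies on:

    import torch
    from layers.losses import L1LossMasked  # added by this patch

    criterion = L1LossMasked()
    pred = torch.zeros(2, 4, 3)             # (batch, max_len, dim)
    target = torch.ones(2, 4, 3)
    lengths = torch.LongTensor([4, 2])      # second item has 2 padded frames

    # Only the 4 + 2 = 6 valid frames contribute; padded frames are masked out,
    # so the average absolute error is 1.0 regardless of how much padding exists.
    loss = criterion(pred, target, lengths)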