mirror of https://github.com/coqui-ai/TTS.git
Attn masking
This commit is contained in:
parent
9f52833151
commit
dac8fdffa9
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
class BahdanauAttention(nn.Module):
|
||||
|
@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
|
|||
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
||||
|
||||
|
||||
def forward(self, memory, context, rnn_state, annotations,
|
||||
attention_vec, mask=None, annotations_lengths=None):
|
||||
def forward(self, memory, context, rnn_state, annots,
|
||||
atten, annot_lens=None):
|
||||
"""
|
||||
Shapes:
|
||||
- memory: (batch, 1, dim) or (batch, dim)
|
||||
- context: (batch, dim)
|
||||
- rnn_state: (batch, out_dim)
|
||||
- annots: (batch, max_time, annot_dim)
|
||||
- atten: (batch, max_time)
|
||||
- annot_lens: (batch,)
|
||||
"""
|
||||
# Concat input query and previous context context
|
||||
rnn_input = torch.cat((memory, context), -1)
|
||||
# Feed it to RNN
|
||||
|
@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
|
|||
# (batch, max_time)
|
||||
# e_{ij} = a(s_{i-1}, h_j)
|
||||
if self.align_model is 'b':
|
||||
alignment = self.alignment_model(annotations, rnn_output)
|
||||
alignment = self.alignment_model(annots, rnn_output)
|
||||
else:
|
||||
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
|
||||
# TODO: needs recheck.
|
||||
if mask is not None:
|
||||
mask = mask.view(query.size(0), -1)
|
||||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
||||
alignment = self.alignment_model(annots, rnn_output, atten)
|
||||
if annot_lens is not None:
|
||||
mask = sequence_mask(annot_lens)
|
||||
mask = mask.view(memory.size(0), -1)
|
||||
alignment.masked_fill_(1 - mask, -float("inf"))
|
||||
# Normalize context weight
|
||||
alignment = F.softmax(alignment, dim=-1)
|
||||
# Attention context vector
|
||||
# (batch, 1, dim)
|
||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||
context = torch.bmm(alignment.unsqueeze(1), annotations)
|
||||
context = torch.bmm(alignment.unsqueeze(1), annots)
|
||||
context = context.squeeze(1)
|
||||
return rnn_output, context, alignment
|
||||
|
|
|
@ -1,26 +1,53 @@
|
|||
import torch
|
||||
from torch.nn import functional
|
||||
from torch import nn
|
||||
from utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def _sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_length_expand = (sequence_length.unsqueeze(1)
|
||||
.expand_as(seq_range_expand))
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
||||
class L2LossMasked(nn.Module):
|
||||
class L1LossMasked(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(L2LossMasked, self).__init__()
|
||||
super(L1LossMasked, self).__init__()
|
||||
|
||||
def forward(self, input, target, length):
|
||||
"""
|
||||
Args:
|
||||
input: A Variable containing a FloatTensor of size
|
||||
(batch, max_len, dim) which contains the
|
||||
unnormalized probability for each class.
|
||||
target: A Variable containing a LongTensor of size
|
||||
(batch, max_len, dim) which contains the index of the true
|
||||
class for each corresponding step.
|
||||
length: A Variable containing a LongTensor of size (batch,)
|
||||
which contains the length of each data in a batch.
|
||||
Returns:
|
||||
loss: An average loss value masked by the length.
|
||||
"""
|
||||
input = input.contiguous()
|
||||
target = target.contiguous()
|
||||
|
||||
# logits_flat: (batch * max_len, dim)
|
||||
input = input.view(-1, input.shape[-1])
|
||||
# target_flat: (batch * max_len, dim)
|
||||
target_flat = target.view(-1, target.shape[-1])
|
||||
# losses_flat: (batch * max_len, dim)
|
||||
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
|
||||
reduce=False)
|
||||
# losses: (batch, max_len, dim)
|
||||
losses = losses_flat.view(*target.size())
|
||||
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).unsqueeze(2)
|
||||
losses = losses * mask.float()
|
||||
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||
return loss
|
||||
|
||||
|
||||
class MSELossMasked(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MSELossMasked, self).__init__()
|
||||
|
||||
def forward(self, input, target, length):
|
||||
"""
|
||||
|
@ -50,7 +77,7 @@ class L2LossMasked(nn.Module):
|
|||
losses = losses_flat.view(*target.size())
|
||||
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = _sequence_mask(sequence_length=length,
|
||||
mask = sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).unsqueeze(2)
|
||||
losses = losses * mask.float()
|
||||
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||
|
|
|
@ -213,7 +213,7 @@ class Decoder(nn.Module):
|
|||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
self.stopnet = StopNet(r, memory_dim)
|
||||
|
||||
def forward(self, inputs, memory=None):
|
||||
def forward(self, inputs, memory=None, input_lens=None):
|
||||
"""
|
||||
Decoder forward step.
|
||||
|
||||
|
@ -225,6 +225,7 @@ class Decoder(nn.Module):
|
|||
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
||||
decoder outputs are used as decoder inputs. If None, it uses the last
|
||||
output as the input.
|
||||
input_lens (None): Time length of each input in batch.
|
||||
|
||||
Shapes:
|
||||
- inputs: batch x time x encoder_out_dim
|
||||
|
@ -273,7 +274,8 @@ class Decoder(nn.Module):
|
|||
# attention_cum.unsqueeze(1)),
|
||||
# dim=1)
|
||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
|
||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||
inputs, attention, input_lens)
|
||||
# attention_cum += attention
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
|
|
|
@ -21,14 +21,14 @@ class Tacotron(nn.Module):
|
|||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
||||
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
||||
|
||||
def forward(self, characters, mel_specs=None):
|
||||
def forward(self, characters, mel_specs=None, text_lens=None):
|
||||
B = characters.size(0)
|
||||
inputs = self.embedding(characters)
|
||||
# batch x time x dim
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
# batch x time x dim*r
|
||||
mel_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs)
|
||||
encoder_outputs, mel_specs, text_lens)
|
||||
# Reshape
|
||||
# batch x time x dim
|
||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||
|
|
4
train.py
4
train.py
|
@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
|
|||
create_experiment_folder, save_checkpoint,
|
||||
save_best_model, load_config, lr_decay,
|
||||
count_parameters, check_update, get_commit_hash)
|
||||
from utils.model import get_param_size
|
||||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from datasets.LJSpeech import LJSpeechDataset
|
||||
from models.tacotron import Tacotron
|
||||
|
@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda()
|
||||
text_lengths = text_lengths.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
|
@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, mel_input)
|
||||
model.forward(text_input, mel_input, text_lengths)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
|
@ -12,7 +12,7 @@ class AudioProcessor(object):
|
|||
|
||||
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||
griffin_lim_iters=None):
|
||||
min_mel_freq, max_mel_freq, griffin_lim_iters=None):
|
||||
self.sample_rate = sample_rate
|
||||
self.num_mels = num_mels
|
||||
self.min_level_db = min_level_db
|
||||
|
@ -22,6 +22,8 @@ class AudioProcessor(object):
|
|||
self.ref_level_db = ref_level_db
|
||||
self.num_freq = num_freq
|
||||
self.power = power
|
||||
self.min_mel_freq = min_mel_freq
|
||||
self.max_mel_freq = max_mel_freq
|
||||
self.griffin_lim_iters = griffin_lim_iters
|
||||
|
||||
def save_wav(self, wav, path):
|
||||
|
@ -36,7 +38,8 @@ class AudioProcessor(object):
|
|||
|
||||
def _build_mel_basis(self, ):
|
||||
n_fft = (self.num_freq - 1) * 2
|
||||
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
|
||||
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
|
||||
fmin=self.min_mel_freq, fmax=self.max_mel_freq)
|
||||
|
||||
def _normalize(self, S):
|
||||
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
|
||||
|
@ -63,7 +66,8 @@ class AudioProcessor(object):
|
|||
return signal.lfilter([1], [1, -self.preemphasis], x)
|
||||
|
||||
def spectrogram(self, y):
|
||||
D = self._stft(self.apply_preemphasis(y))
|
||||
# D = self._stft(self.apply_preemphasis(y))
|
||||
D = self._stft(y)
|
||||
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
|
||||
return self._normalize(S)
|
||||
|
||||
|
@ -72,7 +76,8 @@ class AudioProcessor(object):
|
|||
S = self._denormalize(spectrogram)
|
||||
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
|
||||
# Reconstruct phase
|
||||
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
|
||||
# return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
|
||||
return self._griffin_lim(S ** self.power)
|
||||
|
||||
# def _griffin_lim(self, S):
|
||||
# '''librosa implementation of Griffin-Lim
|
||||
|
@ -110,7 +115,7 @@ class AudioProcessor(object):
|
|||
|
||||
def _istft(self, y):
|
||||
_, hop_length, win_length = self._stft_parameters()
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
||||
|
||||
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
||||
window_length = int(self.sample_rate * min_silence_sec)
|
||||
|
|
Loading…
Reference in New Issue