mirror of https://github.com/coqui-ai/TTS.git
Attn masking
This commit is contained in:
parent
9f52833151
commit
dac8fdffa9
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
from utils.generic_utils import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
class BahdanauAttention(nn.Module):
|
class BahdanauAttention(nn.Module):
|
||||||
|
@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
|
||||||
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
||||||
|
|
||||||
|
|
||||||
def forward(self, memory, context, rnn_state, annotations,
|
def forward(self, memory, context, rnn_state, annots,
|
||||||
attention_vec, mask=None, annotations_lengths=None):
|
atten, annot_lens=None):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
- memory: (batch, 1, dim) or (batch, dim)
|
||||||
|
- context: (batch, dim)
|
||||||
|
- rnn_state: (batch, out_dim)
|
||||||
|
- annots: (batch, max_time, annot_dim)
|
||||||
|
- atten: (batch, max_time)
|
||||||
|
- annot_lens: (batch,)
|
||||||
|
"""
|
||||||
# Concat input query and previous context context
|
# Concat input query and previous context context
|
||||||
rnn_input = torch.cat((memory, context), -1)
|
rnn_input = torch.cat((memory, context), -1)
|
||||||
# Feed it to RNN
|
# Feed it to RNN
|
||||||
|
@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
# e_{ij} = a(s_{i-1}, h_j)
|
# e_{ij} = a(s_{i-1}, h_j)
|
||||||
if self.align_model is 'b':
|
if self.align_model is 'b':
|
||||||
alignment = self.alignment_model(annotations, rnn_output)
|
alignment = self.alignment_model(annots, rnn_output)
|
||||||
else:
|
else:
|
||||||
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
|
alignment = self.alignment_model(annots, rnn_output, atten)
|
||||||
# TODO: needs recheck.
|
if annot_lens is not None:
|
||||||
if mask is not None:
|
mask = sequence_mask(annot_lens)
|
||||||
mask = mask.view(query.size(0), -1)
|
mask = mask.view(memory.size(0), -1)
|
||||||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
alignment.masked_fill_(1 - mask, -float("inf"))
|
||||||
# Normalize context weight
|
# Normalize context weight
|
||||||
alignment = F.softmax(alignment, dim=-1)
|
alignment = F.softmax(alignment, dim=-1)
|
||||||
# Attention context vector
|
# Attention context vector
|
||||||
# (batch, 1, dim)
|
# (batch, 1, dim)
|
||||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||||
context = torch.bmm(alignment.unsqueeze(1), annotations)
|
context = torch.bmm(alignment.unsqueeze(1), annots)
|
||||||
context = context.squeeze(1)
|
context = context.squeeze(1)
|
||||||
return rnn_output, context, alignment
|
return rnn_output, context, alignment
|
||||||
|
|
|
@ -1,26 +1,53 @@
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional
|
from torch.nn import functional
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from utils.generic_utils import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
class L1LossMasked(nn.Module):
|
||||||
def _sequence_mask(sequence_length, max_len=None):
|
|
||||||
if max_len is None:
|
|
||||||
max_len = sequence_length.data.max()
|
|
||||||
batch_size = sequence_length.size(0)
|
|
||||||
seq_range = torch.arange(0, max_len).long()
|
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
|
||||||
if sequence_length.is_cuda:
|
|
||||||
seq_range_expand = seq_range_expand.cuda()
|
|
||||||
seq_length_expand = (sequence_length.unsqueeze(1)
|
|
||||||
.expand_as(seq_range_expand))
|
|
||||||
return seq_range_expand < seq_length_expand
|
|
||||||
|
|
||||||
|
|
||||||
class L2LossMasked(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(L2LossMasked, self).__init__()
|
super(L1LossMasked, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, input, target, length):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input: A Variable containing a FloatTensor of size
|
||||||
|
(batch, max_len, dim) which contains the
|
||||||
|
unnormalized probability for each class.
|
||||||
|
target: A Variable containing a LongTensor of size
|
||||||
|
(batch, max_len, dim) which contains the index of the true
|
||||||
|
class for each corresponding step.
|
||||||
|
length: A Variable containing a LongTensor of size (batch,)
|
||||||
|
which contains the length of each data in a batch.
|
||||||
|
Returns:
|
||||||
|
loss: An average loss value masked by the length.
|
||||||
|
"""
|
||||||
|
input = input.contiguous()
|
||||||
|
target = target.contiguous()
|
||||||
|
|
||||||
|
# logits_flat: (batch * max_len, dim)
|
||||||
|
input = input.view(-1, input.shape[-1])
|
||||||
|
# target_flat: (batch * max_len, dim)
|
||||||
|
target_flat = target.view(-1, target.shape[-1])
|
||||||
|
# losses_flat: (batch * max_len, dim)
|
||||||
|
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
|
||||||
|
reduce=False)
|
||||||
|
# losses: (batch, max_len, dim)
|
||||||
|
losses = losses_flat.view(*target.size())
|
||||||
|
|
||||||
|
# mask: (batch, max_len, 1)
|
||||||
|
mask = sequence_mask(sequence_length=length,
|
||||||
|
max_len=target.size(1)).unsqueeze(2)
|
||||||
|
losses = losses * mask.float()
|
||||||
|
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class MSELossMasked(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MSELossMasked, self).__init__()
|
||||||
|
|
||||||
def forward(self, input, target, length):
|
def forward(self, input, target, length):
|
||||||
"""
|
"""
|
||||||
|
@ -50,7 +77,7 @@ class L2LossMasked(nn.Module):
|
||||||
losses = losses_flat.view(*target.size())
|
losses = losses_flat.view(*target.size())
|
||||||
|
|
||||||
# mask: (batch, max_len, 1)
|
# mask: (batch, max_len, 1)
|
||||||
mask = _sequence_mask(sequence_length=length,
|
mask = sequence_mask(sequence_length=length,
|
||||||
max_len=target.size(1)).unsqueeze(2)
|
max_len=target.size(1)).unsqueeze(2)
|
||||||
losses = losses * mask.float()
|
losses = losses * mask.float()
|
||||||
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||||
|
|
|
@ -213,7 +213,7 @@ class Decoder(nn.Module):
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
self.stopnet = StopNet(r, memory_dim)
|
self.stopnet = StopNet(r, memory_dim)
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None, input_lens=None):
|
||||||
"""
|
"""
|
||||||
Decoder forward step.
|
Decoder forward step.
|
||||||
|
|
||||||
|
@ -225,6 +225,7 @@ class Decoder(nn.Module):
|
||||||
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
||||||
decoder outputs are used as decoder inputs. If None, it uses the last
|
decoder outputs are used as decoder inputs. If None, it uses the last
|
||||||
output as the input.
|
output as the input.
|
||||||
|
input_lens (None): Time length of each input in batch.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- inputs: batch x time x encoder_out_dim
|
- inputs: batch x time x encoder_out_dim
|
||||||
|
@ -273,7 +274,8 @@ class Decoder(nn.Module):
|
||||||
# attention_cum.unsqueeze(1)),
|
# attention_cum.unsqueeze(1)),
|
||||||
# dim=1)
|
# dim=1)
|
||||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
|
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||||
|
inputs, attention, input_lens)
|
||||||
# attention_cum += attention
|
# attention_cum += attention
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
|
|
|
@ -21,14 +21,14 @@ class Tacotron(nn.Module):
|
||||||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
||||||
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
||||||
|
|
||||||
def forward(self, characters, mel_specs=None):
|
def forward(self, characters, mel_specs=None, text_lens=None):
|
||||||
B = characters.size(0)
|
B = characters.size(0)
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(characters)
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
# batch x time x dim*r
|
# batch x time x dim*r
|
||||||
mel_outputs, alignments, stop_tokens = self.decoder(
|
mel_outputs, alignments, stop_tokens = self.decoder(
|
||||||
encoder_outputs, mel_specs)
|
encoder_outputs, mel_specs, text_lens)
|
||||||
# Reshape
|
# Reshape
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||||
|
|
4
train.py
4
train.py
|
@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
save_best_model, load_config, lr_decay,
|
save_best_model, load_config, lr_decay,
|
||||||
count_parameters, check_update, get_commit_hash)
|
count_parameters, check_update, get_commit_hash)
|
||||||
from utils.model import get_param_size
|
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
|
@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input = text_input.cuda()
|
text_input = text_input.cuda()
|
||||||
|
text_lengths = text_lengths.cuda()
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_input = linear_input.cuda()
|
linear_input = linear_input.cuda()
|
||||||
|
@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments, stop_tokens =\
|
mel_output, linear_output, alignments, stop_tokens =\
|
||||||
model.forward(text_input, mel_input)
|
model.forward(text_input, mel_input, text_lengths)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
|
|
|
@ -12,7 +12,7 @@ class AudioProcessor(object):
|
||||||
|
|
||||||
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
|
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||||
griffin_lim_iters=None):
|
min_mel_freq, max_mel_freq, griffin_lim_iters=None):
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.num_mels = num_mels
|
self.num_mels = num_mels
|
||||||
self.min_level_db = min_level_db
|
self.min_level_db = min_level_db
|
||||||
|
@ -22,6 +22,8 @@ class AudioProcessor(object):
|
||||||
self.ref_level_db = ref_level_db
|
self.ref_level_db = ref_level_db
|
||||||
self.num_freq = num_freq
|
self.num_freq = num_freq
|
||||||
self.power = power
|
self.power = power
|
||||||
|
self.min_mel_freq = min_mel_freq
|
||||||
|
self.max_mel_freq = max_mel_freq
|
||||||
self.griffin_lim_iters = griffin_lim_iters
|
self.griffin_lim_iters = griffin_lim_iters
|
||||||
|
|
||||||
def save_wav(self, wav, path):
|
def save_wav(self, wav, path):
|
||||||
|
@ -36,7 +38,8 @@ class AudioProcessor(object):
|
||||||
|
|
||||||
def _build_mel_basis(self, ):
|
def _build_mel_basis(self, ):
|
||||||
n_fft = (self.num_freq - 1) * 2
|
n_fft = (self.num_freq - 1) * 2
|
||||||
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
|
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
|
||||||
|
fmin=self.min_mel_freq, fmax=self.max_mel_freq)
|
||||||
|
|
||||||
def _normalize(self, S):
|
def _normalize(self, S):
|
||||||
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
|
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
|
||||||
|
@ -63,7 +66,8 @@ class AudioProcessor(object):
|
||||||
return signal.lfilter([1], [1, -self.preemphasis], x)
|
return signal.lfilter([1], [1, -self.preemphasis], x)
|
||||||
|
|
||||||
def spectrogram(self, y):
|
def spectrogram(self, y):
|
||||||
D = self._stft(self.apply_preemphasis(y))
|
# D = self._stft(self.apply_preemphasis(y))
|
||||||
|
D = self._stft(y)
|
||||||
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
|
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
|
||||||
return self._normalize(S)
|
return self._normalize(S)
|
||||||
|
|
||||||
|
@ -72,7 +76,8 @@ class AudioProcessor(object):
|
||||||
S = self._denormalize(spectrogram)
|
S = self._denormalize(spectrogram)
|
||||||
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
|
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
|
||||||
# Reconstruct phase
|
# Reconstruct phase
|
||||||
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
|
# return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
|
||||||
|
return self._griffin_lim(S ** self.power)
|
||||||
|
|
||||||
# def _griffin_lim(self, S):
|
# def _griffin_lim(self, S):
|
||||||
# '''librosa implementation of Griffin-Lim
|
# '''librosa implementation of Griffin-Lim
|
||||||
|
@ -110,7 +115,7 @@ class AudioProcessor(object):
|
||||||
|
|
||||||
def _istft(self, y):
|
def _istft(self, y):
|
||||||
_, hop_length, win_length = self._stft_parameters()
|
_, hop_length, win_length = self._stft_parameters()
|
||||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
|
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
||||||
|
|
||||||
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
||||||
window_length = int(self.sample_rate * min_silence_sec)
|
window_length = int(self.sample_rate * min_silence_sec)
|
||||||
|
|
Loading…
Reference in New Issue