mirror of https://github.com/coqui-ai/TTS.git

Attn masking

parent 9f52833151
commit dac8fdffa9

layers/attention.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from utils.generic_utils import sequence_mask
 
 
 class BahdanauAttention(nn.Module):
@@ -91,8 +92,17 @@ class AttentionRNNCell(nn.Module):
             'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
 
 
-    def forward(self, memory, context, rnn_state, annotations,
-                attention_vec, mask=None, annotations_lengths=None):
+    def forward(self, memory, context, rnn_state, annots,
+                atten, annot_lens=None):
+        """
+        Shapes:
+            - memory: (batch, 1, dim) or (batch, dim)
+            - context: (batch, dim)
+            - rnn_state: (batch, out_dim)
+            - annots: (batch, max_time, annot_dim)
+            - atten: (batch, max_time)
+            - annot_lens: (batch,)
+        """
         # Concat input query and previous context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -102,18 +112,18 @@ class AttentionRNNCell(nn.Module):
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
         if self.align_model is 'b':
-            alignment = self.alignment_model(annotations, rnn_output)
+            alignment = self.alignment_model(annots, rnn_output)
         else:
-            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
-        # TODO: needs recheck.
-        if mask is not None:
-            mask = mask.view(query.size(0), -1)
-            alignment.data.masked_fill_(mask, self.score_mask_value)
+            alignment = self.alignment_model(annots, rnn_output, atten)
+        if annot_lens is not None:
+            mask = sequence_mask(annot_lens)
+            mask = mask.view(memory.size(0), -1)
+            alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         alignment = F.softmax(alignment, dim=-1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
-        context = torch.bmm(alignment.unsqueeze(1), annotations)
+        context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
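
Note: the block above is the core of this commit. Padded encoder steps now receive a score of -inf before the softmax, so they get exactly zero attention weight. A minimal standalone sketch of that behavior (re-implementing a sequence_mask equivalent to the removed _sequence_mask helper, and using the modern boolean-mask spelling rather than the old byte-tensor `1 - mask` idiom):

import torch
import torch.nn.functional as F

def sequence_mask(lengths, max_len=None):
    # (batch, max_len) boolean mask, True inside each sequence
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len).unsqueeze(0)        # (1, max_len)
    return ids < lengths.unsqueeze(1)               # (batch, max_len)

annot_lens = torch.tensor([3, 5])                   # true lengths in a batch of 2
scores = torch.randn(2, 5)                          # raw alignment scores (batch, max_time)
mask = sequence_mask(annot_lens, max_len=5)
scores = scores.masked_fill(~mask, -float("inf"))   # pad positions can never win the softmax
alignment = F.softmax(scores, dim=-1)
print(alignment[0])                                 # last two entries are exactly 0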

layers/losses.py
@@ -1,26 +1,53 @@
 import torch
 from torch.nn import functional
 from torch import nn
+from utils.generic_utils import sequence_mask
 
 
-# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def _sequence_mask(sequence_length, max_len=None):
-    if max_len is None:
-        max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
-    return seq_range_expand < seq_length_expand
-
-
-class L2LossMasked(nn.Module):
+class L1LossMasked(nn.Module):
 
     def __init__(self):
-        super(L2LossMasked, self).__init__()
+        super(L1LossMasked, self).__init__()
 
     def forward(self, input, target, length):
         """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len, dim) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+
+        # logits_flat: (batch * max_len, dim)
+        input = input.view(-1, input.shape[-1])
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.shape[-1])
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+                                         reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+
+        # mask: (batch, max_len, 1)
+        mask = sequence_mask(sequence_length=length,
+                             max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
+
+
+class MSELossMasked(nn.Module):
+
+    def __init__(self):
+        super(MSELossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
@@ -48,9 +75,9 @@ class L2LossMasked(nn.Module):
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
 
         # mask: (batch, max_len, 1)
-        mask = _sequence_mask(sequence_length=length,
+        mask = sequence_mask(sequence_length=length,
                              max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
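
Note: both masked losses normalize by the number of real (unpadded) frames, `length.sum() * dim`, rather than by `batch * max_len * dim`, so padding cannot dilute the loss. A small self-contained sketch with made-up shapes (using the modern reduction="none" spelling of size_average=False, reduce=False):

import torch
from torch.nn import functional

batch, max_len, dim = 2, 4, 3
pred = torch.ones(batch, max_len, dim)
target = torch.zeros(batch, max_len, dim)
length = torch.tensor([2, 4])                      # real frames per sample

mask = (torch.arange(max_len).unsqueeze(0) < length.unsqueeze(1)).unsqueeze(2)
losses = functional.l1_loss(pred, target, reduction="none")
loss = (losses * mask.float()).sum() / (length.float().sum() * dim)
assert torch.isclose(loss, torch.tensor(1.0))      # each unmasked |1 - 0| contributes 1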

layers/tacotron.py
@@ -213,7 +213,7 @@
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         self.stopnet = StopNet(r, memory_dim)
 
-    def forward(self, inputs, memory=None):
+    def forward(self, inputs, memory=None, input_lens=None):
         """
         Decoder forward step.
 
@@ -225,6 +225,7 @@
             memory (None): Decoder memory (autoregression). If None (at eval time),
                 decoder outputs are used as decoder inputs, i.e. the last
                 output is fed back as the next input.
+            input_lens (None): Time length of each input in batch.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -273,7 +274,8 @@
             #                           attention_cum.unsqueeze(1)),
             #                          dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
+                processed_memory, current_context_vec, attention_rnn_hidden,
+                inputs, attention, input_lens)
             # attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(

models/tacotron.py
@@ -21,14 +21,14 @@ class Tacotron(nn.Module):
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
 
-    def forward(self, characters, mel_specs=None):
+    def forward(self, characters, mel_specs=None, text_lens=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs)
+            encoder_outputs, mel_specs, text_lens)
         # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)

train.py (4 changed lines)
@@ -22,7 +22,6 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
                                  count_parameters, check_update, get_commit_hash)
-from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
@@ -96,6 +95,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # dispatch data to GPU
         if use_cuda:
             text_input = text_input.cuda()
+            text_lengths = text_lengths.cuda()
             mel_input = mel_input.cuda()
             mel_lengths = mel_lengths.cuda()
             linear_input = linear_input.cuda()
@@ -103,7 +103,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # forward pass
         mel_output, linear_output, alignments, stop_tokens =\
-            model.forward(text_input, mel_input)
+            model.forward(text_input, mel_input, text_lengths)
 
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
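
Note: text_lengths already came out of the data loader; this commit just threads it through the whole stack so the attention module can mask padded encoder steps. A sketch of the new call chain with made-up tensor shapes (model internals elided):

import torch

text_input = torch.randint(0, 40, (2, 7))   # (batch, max_text_len), toy vocab of 40 symbols
mel_input = torch.zeros(2, 12, 80)          # (batch, max_mel_len, num_mels)
text_lengths = torch.tensor([7, 5])         # real text length per sample

# model.forward(text_input, mel_input, text_lengths)
#   -> Tacotron:  self.decoder(encoder_outputs, mel_specs, text_lens)
#     -> Decoder: self.attention_rnn(..., inputs, attention, input_lens)
#       -> AttentionRNNCell: mask = sequence_mask(annot_lens)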

utils/audio.py (249 changed lines)
@@ -1,122 +1,127 @@
 import os
 import librosa
 import pickle
 import copy
 import numpy as np
 from scipy import signal
 
 _mel_basis = None
 
 
 class AudioProcessor(object):
 
     def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
                  frame_length_ms, preemphasis, ref_level_db, num_freq, power,
-                 griffin_lim_iters=None):
+                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db
         self.frame_shift_ms = frame_shift_ms
         self.frame_length_ms = frame_length_ms
         self.preemphasis = preemphasis
         self.ref_level_db = ref_level_db
         self.num_freq = num_freq
         self.power = power
+        self.min_mel_freq = min_mel_freq
+        self.max_mel_freq = max_mel_freq
         self.griffin_lim_iters = griffin_lim_iters
 
     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
         librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
 
     def _linear_to_mel(self, spectrogram):
         global _mel_basis
         if _mel_basis is None:
             _mel_basis = self._build_mel_basis()
         return np.dot(_mel_basis, spectrogram)
 
     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
+        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels,
+                                   fmin=self.min_mel_freq, fmax=self.max_mel_freq)
 
     def _normalize(self, S):
         return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
 
     def _denormalize(self, S):
         return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
 
     def _stft_parameters(self, ):
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
         return n_fft, hop_length, win_length
 
     def _amp_to_db(self, x):
         return 20 * np.log10(np.maximum(1e-5, x))
 
     def _db_to_amp(self, x):
         return np.power(10.0, x * 0.05)
 
     def apply_preemphasis(self, x):
         return signal.lfilter([1, -self.preemphasis], [1], x)
 
     def apply_inv_preemphasis(self, x):
         return signal.lfilter([1], [1, -self.preemphasis], x)
 
     def spectrogram(self, y):
-        D = self._stft(self.apply_preemphasis(y))
+        # D = self._stft(self.apply_preemphasis(y))
+        D = self._stft(y)
         S = self._amp_to_db(np.abs(D)) - self.ref_level_db
         return self._normalize(S)
 
     def inv_spectrogram(self, spectrogram):
         '''Converts spectrogram to waveform using librosa'''
         S = self._denormalize(spectrogram)
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         # Reconstruct phase
-        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        # return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)
 
     # def _griffin_lim(self, S):
     #     '''librosa implementation of Griffin-Lim
     #     Based on https://github.com/librosa/librosa/issues/434
     #     '''
     #     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
     #     S_complex = np.abs(S).astype(np.complex)
     #     y = self._istft(S_complex * angles)
     #     for i in range(self.griffin_lim_iters):
     #         angles = np.exp(1j * np.angle(self._stft(y)))
     #         y = self._istft(S_complex * angles)
     #     return y
 
     def _griffin_lim(self, S):
         '''Applies raw Griffin-Lim phase estimation.
         '''
         S_best = copy.deepcopy(S)
         for i in range(self.griffin_lim_iters):
             S_t = self._istft(S_best)
             est = self._stft(S_t)
             phase = est / np.maximum(1e-8, np.abs(est))
             S_best = S * phase
         S_t = self._istft(S_best)
         y = np.real(S_t)
         return y
 
     def melspectrogram(self, y):
         D = self._stft(self.apply_preemphasis(y))
         S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
         return self._normalize(S)
 
     def _stft(self, y):
         n_fft, hop_length, win_length = self._stft_parameters()
         return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
 
     def _istft(self, y):
         _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length)
 
     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
         hop_length = int(window_length / 4)
         threshold = self._db_to_amp(threshold_db)
         for x in range(hop_length, len(wav) - window_length, hop_length):
             if np.max(wav[x:x + window_length]) < threshold:
                 return x + hop_length
         return len(wav)
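
Note on _griffin_lim: each iteration keeps the target magnitude S but replaces the phase with the phase of the STFT of the current time-domain estimate, converging toward a waveform whose spectrogram magnitude matches S. A standalone sketch of the same loop using librosa directly, with a toy signal and illustrative parameters (not the project's config):

import numpy as np
import librosa

sr, n_fft, hop = 22050, 1024, 256
y = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)          # 1 s of a 440 Hz tone
S = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop))  # magnitude only, phase discarded

S_best = S.copy()
for _ in range(30):
    y_t = librosa.istft(S_best, hop_length=hop)           # back to time domain
    est = librosa.stft(y_t, n_fft=n_fft, hop_length=hop)  # re-analyze
    S_best = S * est / np.maximum(1e-8, np.abs(est))      # keep target magnitude, adopt new phase
y_rec = np.real(librosa.istft(S_best, hop_length=hop))    # reconstructed waveform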
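
For reference, _stft_parameters converts the millisecond-based config into sample counts. A worked example with illustrative values (the real ones live in the training config, which is not part of this diff):

sample_rate = 22050
num_freq = 1025
frame_shift_ms = 12.5
frame_length_ms = 50.0

n_fft = (num_freq - 1) * 2                                 # 2048-point FFT -> 1025 frequency bins
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)    # int(275.625) = 275 samples
win_length = int(frame_length_ms / 1000.0 * sample_rate)   # int(1102.5) = 1102 samples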