mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'graves-discretev2' into dev
This commit is contained in:
commit
a77f6e5d91
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"model": "Tacotron2", // one of the model in models/
|
"model": "Tacotron2", // one of the model in models/
|
||||||
"run_name": "ljspeech-graves",
|
"run_name": "ljspeech-gravesv2",
|
||||||
"run_description": "tacotron2 wuth graves attention",
|
"run_description": "tacotron2 wuth graves attention",
|
||||||
|
|
||||||
// AUDIO PARAMETERS
|
// AUDIO PARAMETERS
|
||||||
|
@ -109,7 +109,7 @@
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"name": "ljspeech",
|
"name": "ljspeech",
|
||||||
"path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/",
|
"path": "/root/LJSpeech-1.1/",
|
||||||
// "path": "/home/erogol/Data/LJSpeech-1.1",
|
// "path": "/home/erogol/Data/LJSpeech-1.1",
|
||||||
"meta_file_train": "metadata_train.csv",
|
"meta_file_train": "metadata_train.csv",
|
||||||
"meta_file_val": "metadata_val.csv"
|
"meta_file_val": "metadata_val.csv"
|
||||||
|
|
|
@ -110,6 +110,85 @@ class LocationLayer(nn.Module):
|
||||||
return processed_attention
|
return processed_attention
|
||||||
|
|
||||||
|
|
||||||
|
class GravesAttention(nn.Module):
|
||||||
|
""" Graves attention as described here:
|
||||||
|
- https://arxiv.org/abs/1910.10288
|
||||||
|
"""
|
||||||
|
COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi))
|
||||||
|
|
||||||
|
def __init__(self, query_dim, K):
|
||||||
|
super(GravesAttention, self).__init__()
|
||||||
|
self._mask_value = 1e-8
|
||||||
|
self.K = K
|
||||||
|
# self.attention_alignment = 0.05
|
||||||
|
self.eps = 1e-5
|
||||||
|
self.J = None
|
||||||
|
self.N_a = nn.Sequential(
|
||||||
|
nn.Linear(query_dim, query_dim, bias=True),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(query_dim, 3*K, bias=True))
|
||||||
|
self.attention_weights = None
|
||||||
|
self.mu_prev = None
|
||||||
|
self.init_layers()
|
||||||
|
|
||||||
|
def init_layers(self):
|
||||||
|
torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) # bias mean
|
||||||
|
torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) # bias std
|
||||||
|
|
||||||
|
def init_states(self, inputs):
|
||||||
|
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
|
||||||
|
self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
|
||||||
|
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
||||||
|
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
||||||
|
|
||||||
|
# pylint: disable=R0201
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def preprocess_inputs(self, inputs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward(self, query, inputs, processed_inputs, mask):
|
||||||
|
"""
|
||||||
|
shapes:
|
||||||
|
query: B x D_attention_rnn
|
||||||
|
inputs: B x T_in x D_encoder
|
||||||
|
processed_inputs: place_holder
|
||||||
|
mask: B x T_in
|
||||||
|
"""
|
||||||
|
gbk_t = self.N_a(query)
|
||||||
|
gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
|
||||||
|
|
||||||
|
# attention model parameters
|
||||||
|
# each B x K
|
||||||
|
g_t = gbk_t[:, 0, :]
|
||||||
|
b_t = gbk_t[:, 1, :]
|
||||||
|
k_t = gbk_t[:, 2, :]
|
||||||
|
|
||||||
|
# attention GMM parameters
|
||||||
|
sig_t = torch.nn.functional.softplus(b_t) + self.eps
|
||||||
|
|
||||||
|
mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
|
||||||
|
g_t = torch.softmax(g_t, dim=-1) + self.eps
|
||||||
|
|
||||||
|
j = self.J[:inputs.size(1)+1]
|
||||||
|
|
||||||
|
# attention weights
|
||||||
|
phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
|
||||||
|
|
||||||
|
# discritize attention weights
|
||||||
|
alpha_t = torch.sum(phi_t, 1)
|
||||||
|
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
|
||||||
|
alpha_t[alpha_t == 0] = 1e-8
|
||||||
|
|
||||||
|
# apply masking
|
||||||
|
if mask is not None:
|
||||||
|
alpha_t.data.masked_fill_(~mask, self._mask_value)
|
||||||
|
|
||||||
|
context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
|
||||||
|
self.attention_weights = alpha_t
|
||||||
|
self.mu_prev = mu_t
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
class OriginalAttention(nn.Module):
|
class OriginalAttention(nn.Module):
|
||||||
"""Following the methods proposed here:
|
"""Following the methods proposed here:
|
||||||
- https://arxiv.org/abs/1712.05884
|
- https://arxiv.org/abs/1712.05884
|
||||||
|
|
|
@ -66,12 +66,11 @@ class AudioProcessor(object):
|
||||||
return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))
|
return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))
|
||||||
|
|
||||||
def _build_mel_basis(self, ):
|
def _build_mel_basis(self, ):
|
||||||
n_fft = (self.num_freq - 1) * 2
|
|
||||||
if self.mel_fmax is not None:
|
if self.mel_fmax is not None:
|
||||||
assert self.mel_fmax <= self.sample_rate // 2
|
assert self.mel_fmax <= self.sample_rate // 2
|
||||||
return librosa.filters.mel(
|
return librosa.filters.mel(
|
||||||
self.sample_rate,
|
self.sample_rate,
|
||||||
n_fft,
|
self.n_fft,
|
||||||
n_mels=self.num_mels,
|
n_mels=self.num_mels,
|
||||||
fmin=self.mel_fmin,
|
fmin=self.mel_fmin,
|
||||||
fmax=self.mel_fmax)
|
fmax=self.mel_fmax)
|
||||||
|
@ -197,6 +196,7 @@ class AudioProcessor(object):
|
||||||
n_fft=self.n_fft,
|
n_fft=self.n_fft,
|
||||||
hop_length=self.hop_length,
|
hop_length=self.hop_length,
|
||||||
win_length=self.win_length,
|
win_length=self.win_length,
|
||||||
|
pad_mode='constant'
|
||||||
)
|
)
|
||||||
|
|
||||||
def _istft(self, y):
|
def _istft(self, y):
|
||||||
|
@ -217,7 +217,7 @@ class AudioProcessor(object):
|
||||||
margin = int(self.sample_rate * 0.01)
|
margin = int(self.sample_rate * 0.01)
|
||||||
wav = wav[margin:-margin]
|
wav = wav[margin:-margin]
|
||||||
return librosa.effects.trim(
|
return librosa.effects.trim(
|
||||||
wav, top_db=60, frame_length=self.win_length, hop_length=self.hop_length)[0]
|
wav, top_db=40, frame_length=self.win_length, hop_length=self.hop_length)[0]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mulaw_encode(wav, qc):
|
def mulaw_encode(wav, qc):
|
||||||
|
|
Loading…
Reference in New Issue