mirror of https://github.com/coqui-ai/TTS.git

commit c381c730c4
Merge branch 'loc-sens-attn' into loc-sens-attn-new and attention without attention-cum
README.md

@@ -23,6 +23,7 @@ Checkout [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the
 | [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410)|First model with plain Tacotron implementation.|
 | [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
 | Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added, to detect end of speech.|
+| Best: [iter-K] | [bla]() | [link]() | Location Sensitive attention |

 ## Data

 Currently TTS provides data loaders for
layers/attention.py

@@ -13,8 +13,8 @@ class BahdanauAttention(nn.Module):
     def forward(self, annots, query):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
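For readers following the diff, here is a minimal, self-contained sketch of the additive (Bahdanau-style) scoring that the attention classes in this file build on. The layer names mirror the query_layer / annot_layer / v projections used in the location-sensitive class added below; class name, dimensions, and the shape check are illustrative assumptions rather than the repository's exact code.

import torch
from torch import nn

class AdditiveScorer(nn.Module):
    # Hedged sketch: additive attention scoring with the same projection layout
    # (query_layer, annot_layer, v) as the classes in this file.
    def __init__(self, annot_dim, query_dim, hidden_dim):
        super(AdditiveScorer, self).__init__()
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, annots, query):
        # annots: (batch, max_time, dim); query: (batch, dim) or (batch, 1, dim)
        if query.dim() == 2:
            query = query.unsqueeze(1)  # insert time-axis for broadcasting
        alignment = self.v(torch.tanh(
            self.query_layer(query) + self.annot_layer(annots)))
        return alignment.squeeze(-1)    # (batch, max_time) unnormalized scores

# illustrative shape check
scorer = AdditiveScorer(annot_dim=256, query_dim=256, hidden_dim=256)
scores = scorer(torch.randn(4, 50, 256), torch.randn(4, 256))
print(scores.shape)  # torch.Size([4, 50])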
@@ -29,31 +29,70 @@ class BahdanauAttention(nn.Module):
         return alignment.squeeze(-1)


-def get_mask_from_lengths(inputs, inputs_lengths):
-    """Get mask tensor from list of length
-
-    Args:
-        inputs: Tensor in size (batch, max_time, dim)
-        inputs_lengths: array like
-    """
-    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
-    for idx, l in enumerate(inputs_lengths):
-        mask[idx][:l] = 1
-    return ~mask
+class LocationSensitiveAttention(nn.Module):
+    """Location sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+    def __init__(self, annot_dim, query_dim, hidden_dim,
+                 kernel_size=7, filters=20):
+        super(LocationSensitiveAttention, self).__init__()
+        self.kernel_size = kernel_size
+        self.filters = filters
+        padding = int((kernel_size - 1) / 2)
+        self.loc_conv = nn.Conv1d(2, filters,
+                                  kernel_size=kernel_size, stride=1,
+                                  padding=padding, bias=False)
+        self.loc_linear = nn.Linear(filters, hidden_dim)
+        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
+        self.v = nn.Linear(hidden_dim, 1, bias=False)
+
+    def forward(self, annot, query, loc):
+        """
+        Shapes:
+            - annot: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
+            - loc: (batch, 2, max_time)
+        """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(nn.functional.tanh(
+            processed_query + processed_annots + processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)


-class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
+class AttentionRNNCell(nn.Module):
+    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
+        r"""
+        General Attention RNN wrapper
+
+        Args:
+            out_dim (int): context vector feature dimension.
+            annot_dim (int): annotation vector feature dimension.
+            memory_dim (int): memory vector (decoder autoregression) feature dimension.
+            align_model (str): 'b' for Bahdanau, 'ls' for Location Sensitive alignment.
+        """
         super(AttentionRNN, self).__init__()
+        self.align_model = align_model
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
-        self.score_mask_value = score_mask_value
+        # pick bahdanau or location sensitive attention
+        if align_model == 'b':
+            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        elif align_model == 'ls':
+            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        else:
+            raise RuntimeError(" Wrong alignment model name: {}. Use\
+                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))

     def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+                attention_vec, mask=None, annotations_lengths=None):
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -62,7 +101,10 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
+        if self.align_model == 'b':
+            alignment = self.alignment_model(annotations, rnn_output)
+        else:
+            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
         # TODO: needs recheck.
         if mask is not None:
             mask = mask.view(query.size(0), -1)
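As a sanity check on the new module, a hedged usage sketch: it builds the two-channel location tensor that the Conv1d(2, filters, ...) layer expects (for example, previous and cumulative alignments stacked, as the commented-out decoder code further below suggests) and confirms the output shape. The concrete dimensions and the way the two channels are filled here are illustrative assumptions, not part of the commit.

import torch

# Assumed sizes for illustration only.
batch, max_time = 2, 50
annot_dim = query_dim = 256
hidden_dim = 128

attn = LocationSensitiveAttention(annot_dim, query_dim, hidden_dim)  # class added above

annots = torch.randn(batch, max_time, annot_dim)   # encoder outputs
query = torch.randn(batch, query_dim)              # current decoder state
prev_align = torch.zeros(batch, max_time)          # alignment from the previous step
cum_align = torch.zeros(batch, max_time)           # running sum of alignments
loc = torch.stack((prev_align, cum_align), dim=1)  # (batch, 2, max_time), matches Conv1d(2, ...)

scores = attn(annots, query, loc)                  # unnormalized alignment scores
print(scores.shape)                                # expected: torch.Size([2, 50])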
layers/losses.py

@@ -17,10 +17,10 @@ def _sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand


-class L1LossMasked(nn.Module):
+class L2LossMasked(nn.Module):

     def __init__(self):
-        super(L1LossMasked, self).__init__()
+        super(L2LossMasked, self).__init__()

     def forward(self, input, target, length):
         """

@@ -44,7 +44,7 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
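To make the change from L1 to L2 concrete, here is a hedged, standalone sketch of what a masked sequence loss of this kind computes: a mask built from per-item lengths, element-wise MSE, and averaging only over the valid (non-padded) positions. The function and variable names, and the use of the newer reduction='none' API, are illustrative assumptions rather than the file's exact code.

import torch
from torch.nn import functional

def masked_mse_loss(input, target, lengths):
    # input, target: (batch, max_len, dim); lengths: (batch,) valid frames per item
    batch, max_len, dim = target.size()
    # (batch, max_len) mask: 1.0 where the frame is inside the valid length
    mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).float()
    losses = functional.mse_loss(input, target, reduction='none')  # element-wise
    losses = losses * mask.unsqueeze(-1)                           # zero out padded frames
    return losses.sum() / (mask.sum() * dim)                       # mean over valid entries

# illustrative usage
pred = torch.randn(2, 5, 80)
gt = torch.randn(2, 5, 80)
lens = torch.tensor([5, 3])
print(masked_mse_loss(pred, gt, lens))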
layers/tacotron.py

@@ -1,9 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .attention import AttentionRNN
-from .attention import get_mask_from_lengths
+from .attention import AttentionRNNCell


 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -12,7 +10,7 @@ class Prenet(nn.Module):
     Args:
         in_features (int): size of the input vector
         out_features (int or list): size of each output sample.
             If it is a list, for each value, there is created a new layer.
     """

     def __init__(self, in_features, out_features=[256, 128]):
@@ -162,7 +160,7 @@ class CBHG(nn.Module):
             x = highway(x)
         # (B, T_in, in_features*2)
         # TODO: replace GRU with convolution as in Deep Voice 3
-        self.gru.flatten_parameters()
+        # self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
         return outputs

@@ -195,7 +193,6 @@ class Decoder(nn.Module):
         in_features (int): input vector (encoder output) sample size.
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
-        eps (float): threshold for detecting the end of a sentence.
     """

     def __init__(self, in_features, memory_dim, r):
@@ -205,8 +202,8 @@ class Decoder(nn.Module):
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
-        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
-        self.attention_rnn = AttentionRNN(256, in_features, 128)
+        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
+        self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -234,6 +231,7 @@ class Decoder(nn.Module):
            - memory: batch x #mel_specs x mel_spec_dim
         """
         B = inputs.size(0)
+        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
@@ -243,19 +241,22 @@ class Decoder(nn.Module):
                 " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                     self.memory_dim, self.r)
             T_decoder = memory.size(1)
-        # go frame - 0 frames tarting the sequence
+        # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
-        # Init decoder states
+        # decoder states
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        # attention states
+        attention = inputs.data.new(B, T).zero_()
+        # attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
         outputs = []
-        alignments = []
+        attentions = []
         stop_tokens = []
         t = 0
         memory_input = initial_memory
@@ -268,8 +269,12 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
+            # attention_cat = torch.cat((attention.unsqueeze(1),
+            #                            attention_cum.unsqueeze(1)),
+            #                           dim=1)
+            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
+            # attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
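The commented-out lines above sketch how a two-channel location signal (previous plus cumulative alignment) would be assembled before each attention call; in this commit only the previous alignment is carried forward ("attention without attention-cum", per the commit message). A hedged, standalone illustration of that per-step bookkeeping, with assumed shapes and a random stand-in for the new alignment:

import torch

B, T = 2, 50                             # assumed batch size and encoder length
attention = torch.zeros(B, T)            # alignment carried over from the previous step
attention_cum = torch.zeros(B, T)        # running sum of alignments (disabled in this commit)

for step in range(3):                    # a few illustrative decoder steps
    # stack previous and cumulative alignments into a (B, 2, T) location input
    attention_cat = torch.cat((attention.unsqueeze(1),
                               attention_cum.unsqueeze(1)), dim=1)
    # ... attention_cat would feed the location convolution here ...
    attention = torch.softmax(torch.randn(B, T), dim=-1)  # stand-in for the new alignment
    attention_cum = attention_cum + attention              # accumulate for the next step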
@@ -286,14 +291,14 @@ class Decoder(nn.Module):
             # predict stop token
             stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
-            alignments += [alignment]
+            attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -301,28 +306,35 @@ class Decoder(nn.Module):
                     break
         assert greedy or len(outputs) == T_decoder
         # Back to batch first
-        alignments = torch.stack(alignments).transpose(0, 1)
+        attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, attentions, stop_tokens


 class StopNet(nn.Module):
     r"""
     Predicting stop-token in decoder.

     Args:
         r (int): number of output frames of the network.
         memory_dim (int): feature dimension for each output frame.
     """

     def __init__(self, r, memory_dim):
+        r"""
+        Predicts the stop token to stop the decoder at testing time
+
+        Args:
+            r (int): number of network output frames.
+            memory_dim (int): single feature dim of a single network output frame.
+        """
         super(StopNet, self).__init__()
         self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
         self.relu = nn.ReLU()
         self.linear = nn.Linear(r * memory_dim, 1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, inputs, rnn_hidden):
         """
         Args:

@@ -333,4 +345,4 @@ class StopNet(nn.Module):
         outputs = self.relu(rnn_hidden)
         outputs = self.linear(outputs)
         outputs = self.sigmoid(outputs)
         return outputs, rnn_hidden
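Finally, a hedged sketch of how the stop-token prediction is consumed at inference time, mirroring the decoder loop above: decoding stops once the sigmoid output exceeds 0.6 after at least half of the input length has been generated, or when max_decoder_steps is reached. The frame sizes and the stand-in decoder output are illustrative assumptions.

import torch

r, memory_dim = 5, 80                          # assumed output-frame settings
stopnet = StopNet(r, memory_dim)               # class defined above
stopnet_hidden = torch.zeros(1, r * memory_dim)

input_len, max_decoder_steps = 60, 200         # assumed encoder length and step limit
for t in range(max_decoder_steps):
    frame = torch.randn(1, r * memory_dim)     # stand-in for the current decoder output
    stop_token, stopnet_hidden = stopnet(frame, stopnet_hidden)
    if t > input_len / 2 and stop_token > 0.6:
        break                                  # end of speech predicted
    # otherwise keep decoding until max_decoder_steps is hit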
File diff suppressed because one or more lines are too long

train.py
@@ -26,7 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from layers.losses import L1LossMasked
+from layers.losses import L2LossMasked

 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()