mirror of https://github.com/coqui-ai/TTS.git
comment model outputs for tacotron, align output shapes of tacotron and tacotron2, merge bidirectional decoder
parent 1c2d26ccb9
commit e83a4b07d2
config.json

@@ -46,6 +46,7 @@
 "forward_attn_mask": false,
 "transition_agent": false, // enable/disable transition agent of forward attention.
 "location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
+"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
 "loss_masking": true, // enable / disable loss masking against the sequence padding.
 "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
 "stopnet": true, // Train stopnet predicting the end of synthesis.
@@ -82,7 +83,8 @@
 [
   {
     "name": "ljspeech",
-    "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+    // "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+    "path": "/home/erogol/Data/LJSpeech-1.1",
     "meta_file_train": "metadata_train.csv",
     "meta_file_val": "metadata_val.csv"
   }
layers/tacotron.py

@@ -1,7 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, Attention
+from .common_layers import Prenet, Attention, Linear


 class BatchNormConv1d(nn.Module):
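Note on the new import: judging by its call sites later in this diff (`Linear(..., bias=True, init_gain='sigmoid')`), `common_layers.Linear` is presumably a thin `nn.Linear` wrapper with gain-aware Xavier initialization. A minimal sketch, assuming that behavior:

    from torch import nn

    class Linear(nn.Module):
        # sketch only; the real TTS.layers.common_layers.Linear may differ
        def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
            super().__init__()
            self.linear_layer = nn.Linear(in_features, out_features, bias=bias)
            # scale the Xavier init for the nonlinearity that follows the layer
            nn.init.xavier_uniform_(self.linear_layer.weight,
                                    gain=nn.init.calculate_gain(init_gain))

        def forward(self, x):
            return self.linear_layer(x)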
@@ -125,8 +125,7 @@ class CBHG(nn.Module):
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
         self.conv1d_banks = nn.ModuleList([
-            BatchNormConv1d(
-                in_features,
+            BatchNormConv1d(in_features,
                             conv_bank_features,
                             kernel_size=k,
                             stride=1,
@@ -142,8 +141,7 @@ class CBHG(nn.Module):
         layer_set = []
         for (in_size, out_size, ac) in zip(out_features, conv_projections,
                                            activations):
-            layer = BatchNormConv1d(
-                in_size,
+            layer = BatchNormConv1d(in_size,
                                     out_size,
                                     kernel_size=3,
                                     stride=1,
@@ -153,28 +151,23 @@ class CBHG(nn.Module):
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
         if self.highway_features != conv_projections[-1]:
-            self.pre_highway = nn.Linear(
-                conv_projections[-1], highway_features, bias=False)
+            self.pre_highway = nn.Linear(conv_projections[-1],
+                                         highway_features,
+                                         bias=False)
         self.highways = nn.ModuleList([
             Highway(highway_features, highway_features)
             for _ in range(num_highways)
         ])
         # bi-directional GRU layer
-        self.gru = nn.GRU(
-            gru_features,
+        self.gru = nn.GRU(gru_features,
                           gru_features,
                           1,
                           batch_first=True,
                           bidirectional=True)

     def forward(self, inputs):
-        # (B, T_in, in_features)
-        x = inputs
-        # Needed to perform conv1d on time-axis
         # (B, in_features, T_in)
-        if x.size(-1) == self.in_features:
-            x = x.transpose(1, 2)
-        # T = x.size(-1)
+        x = inputs
         # (B, hid_features*K, T_in)
         # Concat conv1d bank outputs
         outs = []
@@ -185,10 +178,8 @@ class CBHG(nn.Module):
         assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
         for conv1d in self.conv1d_projections:
             x = conv1d(x)
-        # (B, T_in, hid_feature)
-        x = x.transpose(1, 2)
-        # Back to the original shape
         x += inputs
+        x = x.transpose(1, 2)
         if self.highway_features != self.conv_projections[-1]:
             x = self.pre_highway(x)
         # Residual connection
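The reworked `forward` now expects channel-first input and stays channel-first through the residual add, transposing to time-major only before the highway layers. The underlying constraint is that `nn.Conv1d` convolves over the last axis; a standalone sketch with invented sizes:

    import torch
    from torch import nn

    conv = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
    x = torch.randn(2, 37, 256)      # B x T x C, the usual time-major layout
    y = conv(x.transpose(1, 2))      # Conv1d convolves the last axis, so feed B x C x T
    print(y.shape)                   # torch.Size([2, 128, 37])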
@@ -236,8 +227,10 @@ class Encoder(nn.Module):
             - inputs: batch x time x in_features
             - outputs: batch x time x 128*2
         """
-        inputs = self.prenet(inputs)
-        return self.cbhg(inputs)
+        # B x T x prenet_dim
+        outputs = self.prenet(inputs)
+        outputs = self.cbhg(outputs.transpose(1, 2))
+        return outputs


 class PostCBHG(nn.Module):
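Since CBHG no longer transposes internally, `Encoder.forward` does it at the call site. A quick shape check, assuming the dimensions implied by the docstring and a single `in_features` constructor argument:

    import torch
    from TTS.layers.tacotron import Encoder

    enc = Encoder(256)           # in_features; 256 is an assumed value
    x = torch.randn(2, 37, 256)  # B x T_in x in_features
    y = enc(x)
    print(y.shape)               # batch x time x 128*2 -> torch.Size([2, 37, 256])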
@@ -314,7 +307,12 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
         # learn init values instead of zero init.
-        self.stopnet = StopNet(256 + memory_dim * self.r_init)
+        self.stopnet = nn.Sequential(
+            nn.Dropout(0.1),
+            Linear(256 + memory_dim * self.r_init,
+                   1,
+                   bias=True,
+                   init_gain='sigmoid'))

     def set_r(self, new_r):
         self.r = new_r
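The standalone `StopNet` module (commented out at the end of this file) is replaced by an inline Dropout + Linear stack, matching the Tacotron2 decoder. Its input is the decoder RNN state concatenated with the predicted frame group; a sketch with plain `nn.Linear` standing in for the project's `Linear` wrapper:

    import torch
    from torch import nn

    memory_dim, r_init = 80, 5                    # toy values
    stopnet = nn.Sequential(
        nn.Dropout(0.1),
        nn.Linear(256 + memory_dim * r_init, 1))  # stand-in for common_layers.Linear

    rnn_state = torch.randn(2, 256)
    frames = torch.randn(2, memory_dim * r_init)
    logit = stopnet(torch.cat([rnn_state, frames], dim=-1))  # B x 1 raw stop logit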
@@ -356,8 +354,9 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
         attentions = torch.stack(attentions).transpose(0, 1)
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
+        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
         return outputs, attentions, stop_tokens

     def decode(self, inputs, mask=None):
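This is the shape-alignment half of the commit: `_parse_outputs` in both decoders now returns outputs as B x memory_dim x T. A toy illustration of the stack/view step (sizes invented):

    import torch

    B, memory_dim, r, steps = 2, 80, 5, 7
    outputs = [torch.randn(B, memory_dim * r) for _ in range(steps)]  # one tensor per decoder step
    stacked = torch.stack(outputs).transpose(0, 1).contiguous()       # B x steps x (memory_dim * r)
    reshaped = stacked.view(B, memory_dim, -1)                        # B x memory_dim x (steps * r)
    print(reshaped.shape)                                             # torch.Size([2, 80, 35])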
@@ -438,9 +437,8 @@ class Decoder(nn.Module):
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
-            stop_tokens += [stop_token]
+            stop_tokens += [stop_token.squeeze(1)]
             t += 1

         return self._parse_outputs(outputs, attentions, stop_tokens)

     def inference(self, inputs, speaker_embeddings=None):
@@ -481,20 +479,20 @@
         return self._parse_outputs(outputs, attentions, stop_tokens)


-class StopNet(nn.Module):
-    r"""
-    Args:
-        in_features (int): feature dimension of input.
-    """
+# class StopNet(nn.Module):
+#     r"""
+#     Args:
+#         in_features (int): feature dimension of input.
+#     """

-    def __init__(self, in_features):
-        super(StopNet, self).__init__()
-        self.dropout = nn.Dropout(0.1)
-        self.linear = nn.Linear(in_features, 1)
-        torch.nn.init.xavier_uniform_(
-            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
+#     def __init__(self, in_features):
+#         super(StopNet, self).__init__()
+#         self.dropout = nn.Dropout(0.1)
+#         self.linear = nn.Linear(in_features, 1)
+#         torch.nn.init.xavier_uniform_(
+#             self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))

-    def forward(self, inputs):
-        outputs = self.dropout(inputs)
-        outputs = self.linear(outputs)
-        return outputs
+#     def forward(self, inputs):
+#         outputs = self.dropout(inputs)
+#         outputs = self.linear(outputs)
+#         return outputs
layers/tacotron2.py

@@ -98,11 +98,12 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
+    def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, separate_stopnet):
+                 forward_attn_mask, location_attn, separate_stopnet,
+                 speaker_embedding_dim):
         super(Decoder, self).__init__()
-        self.mel_channels = inputs_dim
+        self.memory_dim = memory_dim
         self.r_init = r
         self.r = r
         self.encoder_embedding_dim = in_features
@@ -114,10 +115,14 @@ class Decoder(nn.Module):
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-        self.prenet = Prenet(self.mel_channels,
+        # memory -> |Prenet| -> processed_memory
+        prenet_dim = self.memory_dim + speaker_embedding_dim
+        self.prenet = Prenet(
+            prenet_dim,
             prenet_type,
             prenet_dropout,
-            [self.prenet_dim, self.prenet_dim],
+            out_features=[self.prenet_dim, self.prenet_dim],
             bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
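The prenet input grows by `speaker_embedding_dim` because the decoder now concatenates the speaker embedding onto each memory frame before the prenet (see the `forward` and `inference` hunks below). A toy sketch of that concatenation:

    import torch

    memory_dim, speaker_embedding_dim = 80, 80   # invented sizes
    memory = torch.randn(2, memory_dim)          # one decoder input frame
    spk = torch.randn(2, speaker_embedding_dim)
    prenet_in = torch.cat([memory, spk], dim=-1) # B x (memory_dim + speaker_embedding_dim)
    print(prenet_in.shape)                       # torch.Size([2, 160])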
@@ -139,11 +144,11 @@ class Decoder(nn.Module):
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.mel_channels * self.r_init)
+                                        self.memory_dim * self.r_init)

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
+            Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
                    1,
                    bias=True,
                    init_gain='sigmoid'))
@@ -155,7 +160,7 @@ class Decoder(nn.Module):
     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = torch.zeros(1, device=inputs.device).repeat(B,
-                             self.mel_channels * self.r)
+                             self.memory_dim * self.r)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
@@ -185,16 +190,14 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, stop_tokens, alignments):
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
-        outputs = outputs.transpose(1, 2)
+        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
         return outputs, stop_tokens, alignments

     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.mel_channels * (self.r - 1):]
-        return memory[:, :, self.mel_channels * (self.r - 1):]
+            return memory[:, self.memory_dim * (self.r - 1):]
+        return memory[:, :, self.memory_dim * (self.r - 1):]

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
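For context, `_update_memory` keeps only the last of the r frames a decoder step emits, which becomes the next autoregressive input; the rename to `memory_dim` does not change that logic. A toy sketch:

    import torch

    memory_dim, r = 80, 2
    step_output = torch.randn(2, memory_dim * r)        # B x (memory_dim * r)
    next_input = step_output[:, memory_dim * (r - 1):]  # last frame only, B x memory_dim
    print(next_input.shape)                             # torch.Size([2, 80])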
@@ -228,10 +231,10 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        decoder_output = decoder_output[:, :self.r * self.mel_channels]
+        decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, stop_token, self.attention.attention_weights

-    def forward(self, inputs, memories, mask):
+    def forward(self, inputs, memories, mask, speaker_embeddings=None):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
|
@ -243,6 +246,8 @@ class Decoder(nn.Module):
|
||||||
outputs, stop_tokens, alignments = [], [], []
|
outputs, stop_tokens, alignments = [], [], []
|
||||||
while len(outputs) < memories.size(0) - 1:
|
while len(outputs) < memories.size(0) - 1:
|
||||||
memory = memories[len(outputs)]
|
memory = memories[len(outputs)]
|
||||||
|
if speaker_embeddings is not None:
|
||||||
|
memory = torch.cat([memory, speaker_embeddings], dim=-1)
|
||||||
mel_output, stop_token, attention_weights = self.decode(memory)
|
mel_output, stop_token, attention_weights = self.decode(memory)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [mel_output.squeeze(1)]
|
||||||
stop_tokens += [stop_token.squeeze(1)]
|
stop_tokens += [stop_token.squeeze(1)]
|
||||||
|
@@ -253,7 +258,7 @@ class Decoder(nn.Module):

         return outputs, stop_tokens, alignments

-    def inference(self, inputs):
+    def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
         memory = self._update_memory(memory)
@@ -266,6 +271,8 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(memory)
+            if speaker_embeddings is not None:
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, stop_token, alignment = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
models/tacotron.py

@@ -1,5 +1,6 @@
 # coding: utf-8
 import torch
+import copy
 from torch import nn
 from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
 from TTS.utils.generic_utils import sequence_mask
@@ -11,8 +12,8 @@ class Tacotron(nn.Module):
                  num_chars,
                  num_speakers,
                  r=5,
-                 linear_dim=1025,
-                 mel_dim=80,
+                 postnet_output_dim=1025,
+                 decoder_output_dim=80,
                  memory_size=5,
                  attn_win=False,
                  gst=False,
@@ -23,28 +24,33 @@ class Tacotron(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron, self).__init__()
         self.r = r
-        self.mel_dim = mel_dim
-        self.linear_dim = linear_dim
+        self.decoder_output_dim = decoder_output_dim
+        self.postnet_output_dim = postnet_output_dim
         self.gst = gst
         self.num_speakers = num_speakers
-        self.embedding = nn.Embedding(num_chars, 256)
-        self.embedding.weight.data.normal_(0, 0.3)
+        self.bidirectional_decoder = bidirectional_decoder
         decoder_dim = 512 if num_speakers > 1 else 256
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
+        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
-        self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
+        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, separate_stopnet,
                                proj_speaker_dim)
-        self.postnet = PostCBHG(mel_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = PostCBHG(decoder_output_dim)
         self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
-                                     linear_dim)
+                                     postnet_output_dim)
         # speaker embedding layers
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
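The backward decoder is a deep copy of the forward one: identical architecture, independently trained weights. A minimal sketch of the pattern with a stand-in module:

    import copy
    from torch import nn

    decoder = nn.GRU(256, 256, batch_first=True)  # stand-in for the Tacotron decoder
    decoder_backward = copy.deepcopy(decoder)     # fresh parameters, nothing shared
    assert decoder.weight_ih_l0 is not decoder_backward.weight_ih_l0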
@@ -82,27 +88,48 @@ class Tacotron(nn.Module):
         return inputs

     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
+        """
+        Shapes:
+            - characters: B x T_in
+            - text_lengths: B
+            - mel_specs: B x T_out x D
+            - speaker_ids: B x 1
+        """
+        self._init_states()
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
+        # B x T_in x embed_dim
         inputs = self.embedding(characters)
-        self._init_states()
+        # B x speaker_embed_dim
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
+            # B x T_in x embed_dim + speaker_embed_dim
             inputs = self._concat_speaker_embedding(inputs,
                                                     self.speaker_embeddings)
+        # B x T_in x encoder_dim
         encoder_outputs = self.encoder(inputs)
         if self.gst:
+            # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder(
+        # decoder_outputs: B x decoder_dim x T_out
+        # alignments: B x T_in x encoder_dim
+        # stop_tokens: B x T_in
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask,
             self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        # B x T_out x decoder_dim
+        postnet_outputs = self.postnet(decoder_outputs)
+        # B x T_out x posnet_dim
+        postnet_outputs = self.last_linear(postnet_outputs)
+        # B x T_out x decoder_dim
+        decoder_outputs = decoder_outputs.transpose(1, 2)
+        if self.bidirectional_decoder:
+            decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
+            return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, characters, speaker_ids=None, style_mel=None):
         B = characters.size(0)
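Per the commit message, the decoder output now comes back channel-first (B x decoder_dim x T_out) and is transposed to time-major before returning, so Tacotron and Tacotron2 expose the same shapes to train.py. A toy illustration:

    import torch

    decoder_outputs = torch.randn(2, 80, 35)           # B x decoder_dim x T_out
    decoder_outputs = decoder_outputs.transpose(1, 2)  # B x T_out x decoder_dim
    print(decoder_outputs.shape)                       # torch.Size([2, 35, 80])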
@@ -118,12 +145,19 @@ class Tacotron(nn.Module):
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs, self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        decoder_outputs = decoder_outputs.view(B, -1, self.decoder_output_dim)
+        postnet_outputs = self.postnet(decoder_outputs)
+        postnet_outputs = self.last_linear(postnet_outputs)
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens
+
+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b

     def _compute_speaker_embedding(self, speaker_ids):
         speaker_embeddings = self.speaker_embedding(speaker_ids)
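`_backward_inference` runs the copied decoder on time-reversed targets; `torch.flip` on dim 1 reverses the frame order, and train.py flips the backward output back before comparing it with forward results. A tiny demonstration:

    import torch

    mel_specs = torch.arange(6.).view(1, 3, 2)        # B x T x D
    reversed_specs = torch.flip(mel_specs, dims=(1,))
    print(reversed_specs[0, :, 0])                    # tensor([4., 2., 0.]) - frames reversed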
models/tacotron2.py

@@ -1,3 +1,5 @@
+import copy
+import torch
 from math import sqrt
 from torch import nn
 from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -10,6 +12,8 @@ class Tacotron2(nn.Module):
                  num_chars,
                  num_speakers,
                  r,
+                 postnet_output_dim=80,
+                 decoder_output_dim=80,
                  attn_win=False,
                  attn_norm="softmax",
                  prenet_type="original",
@@ -18,10 +22,16 @@ class Tacotron2(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron2, self).__init__()
-        self.n_mel_channels = 80
+        self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
+        self.bidirectional_decoder = bidirectional_decoder
+        decoder_dim = 512 + 256 if num_speakers > 1 else 512
+        encoder_dim = 512 + 256 if num_speakers > 1 else 512
+        proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
@@ -29,12 +39,18 @@ class Tacotron2(nn.Module):
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
+        self.encoder = Encoder(encoder_dim)
+        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet)
-        self.postnet = Postnet(self.n_mel_channels)
+                               location_attn, separate_stopnet, proj_speaker_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = Postnet(self.decoder_output_dim)
+
+    def _init_states(self):
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None

     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
|
@ -43,19 +59,23 @@ class Tacotron2(nn.Module):
|
||||||
return mel_outputs, mel_outputs_postnet, alignments
|
return mel_outputs, mel_outputs_postnet, alignments
|
||||||
|
|
||||||
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
|
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
|
||||||
|
self._init_states()
|
||||||
# compute mask for padding
|
# compute mask for padding
|
||||||
mask = sequence_mask(text_lengths).to(text.device)
|
mask = sequence_mask(text_lengths).to(text.device)
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||||
speaker_ids)
|
speaker_ids)
|
||||||
mel_outputs, stop_tokens, alignments = self.decoder(
|
decoder_outputs, stop_tokens, alignments = self.decoder(
|
||||||
encoder_outputs, mel_specs, mask)
|
encoder_outputs, mel_specs, mask)
|
||||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
postnet_outputs = self.postnet(decoder_outputs)
|
||||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
postnet_outputs = decoder_outputs + postnet_outputs
|
||||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||||
mel_outputs, mel_outputs_postnet, alignments)
|
decoder_outputs, postnet_outputs, alignments)
|
||||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
if self.bidirectional_decoder:
|
||||||
|
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||||
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||||
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||||
|
|
||||||
def inference(self, text, speaker_ids=None):
|
def inference(self, text, speaker_ids=None):
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
|
@@ -86,6 +106,13 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b
+
     def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
         if hasattr(self, "speaker_embedding") and speaker_ids is None:
             raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
train.py
@@ -88,6 +88,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         'avg_loader_time': 0,
         'avg_alignment_score': 0
     }
+    if c.bidirectional_decoder:
+        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
+        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
     keep_avg = KeepAverage()
     keep_avg.add_values(train_values)
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
@@ -150,6 +153,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             speaker_ids = speaker_ids.cuda(non_blocking=True)

         # forward pass model
+        if c.bidirectional_decoder:
+            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+        else:
             decoder_output, postnet_output, alignments, stop_tokens = model(
                 text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
@@ -174,6 +181,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if not c.separate_stopnet and c.stopnet:
             loss += stop_loss

+        # backward decoder
+        if c.bidirectional_decoder:
+            if c.loss_masking:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+            else:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+            decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+            loss = decoder_backward_loss + decoder_c_loss
+            keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
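Sketched standalone, the two new terms: the backward decoder is fit to the target mels after its output is re-reversed, and an L1 consistency loss pulls it toward the forward decoder's output, per https://arxiv.org/abs/1907.09006. `F.l1_loss` stands in here for the configured `criterion`:

    import torch
    import torch.nn.functional as F

    B, T, D = 2, 35, 80
    decoder_output = torch.randn(B, T, D)           # forward decoder, time-major
    decoder_backward_output = torch.randn(B, T, D)  # produced from flipped mels
    mel_input = torch.randn(B, T, D)

    flipped_back = torch.flip(decoder_backward_output, dims=(1,))
    decoder_backward_loss = F.l1_loss(flipped_back, mel_input)
    decoder_c_loss = F.l1_loss(flipped_back, decoder_output)
    loss = decoder_backward_loss + decoder_c_loss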
@@ -445,7 +462,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 "ground_truth": plot_spectrogram(gt_spec, ap),
                 "alignment": plot_alignment(align_img)
             }
-            tb_logger.tb_eval_figures(global_step, eval_figures)

             # Sample audio
             if c.model in ["Tacotron", "TacotronGST"]:
|
@ -461,7 +477,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
"loss_decoder": keep_avg['avg_decoder_loss'],
|
"loss_decoder": keep_avg['avg_decoder_loss'],
|
||||||
"stop_loss": keep_avg['avg_stop_loss']
|
"stop_loss": keep_avg['avg_stop_loss']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.bidirectional_decoder:
|
||||||
|
epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_backward']
|
||||||
|
epoch_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
|
||||||
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
||||||
|
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||||
|
|
||||||
|
|
||||||
if args.rank == 0 and epoch > c.test_delay_epochs:
|
if args.rank == 0 and epoch > c.test_delay_epochs:
|
||||||
# test sentences
|
# test sentences
|
||||||
|
|
|
utils/generic_utils.py

@@ -283,8 +283,8 @@ def setup_model(num_chars, num_speakers, c):
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
-                        linear_dim=1025,
-                        mel_dim=80,
+                        postnet_output_dim=c.audio['num_freq'],
+                        decoder_output_dim=c.audio['num_mels'],
                         gst=c.use_gst,
                         memory_size=c.memory_size,
                         attn_win=c.windowing,
@@ -295,11 +295,14 @@ def setup_model(num_chars, num_speakers, c):
                         trans_agent=c.transition_agent,
                         forward_attn_mask=c.forward_attn_mask,
                         location_attn=c.location_attn,
-                        separate_stopnet=c.separate_stopnet)
+                        separate_stopnet=c.separate_stopnet,
+                        bidirectional_decoder=c.bidirectional_decoder)
     elif c.model.lower() == "tacotron2":
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
+                        postnet_output_dim=c.audio['num_mels'],
+                        decoder_output_dim=c.audio['num_mels'],
                         attn_win=c.windowing,
                         attn_norm=c.attention_norm,
                         prenet_type=c.prenet_type,
|
@ -308,7 +311,8 @@ def setup_model(num_chars, num_speakers, c):
|
||||||
trans_agent=c.transition_agent,
|
trans_agent=c.transition_agent,
|
||||||
forward_attn_mask=c.forward_attn_mask,
|
forward_attn_mask=c.forward_attn_mask,
|
||||||
location_attn=c.location_attn,
|
location_attn=c.location_attn,
|
||||||
separate_stopnet=c.separate_stopnet)
|
separate_stopnet=c.separate_stopnet,
|
||||||
|
bidirectional_decoder=c.bidirectional_decoder)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
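The dimension mapping differs per model because Tacotron's CBHG postnet predicts a full linear spectrogram while Tacotron2's convolutional postnet refines the mel spectrogram in place. Illustrative values only (1025 and 80 match the old hard-coded defaults):

    num_freq, num_mels = 1025, 80
    tacotron_dims = {"decoder_output_dim": num_mels, "postnet_output_dim": num_freq}
    tacotron2_dims = {"decoder_output_dim": num_mels, "postnet_output_dim": num_mels}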