mirror of https://github.com/coqui-ai/TTS.git
commention model outputs for tacotron, align outputs shapes of tacotron and tracotron2, merge bidirectional decoder
This commit is contained in:
parent
1c2d26ccb9
commit
e83a4b07d2
|
@ -46,6 +46,7 @@
|
|||
"forward_attn_mask": false,
|
||||
"transition_agent": false, // enable/disable transition agent of forward attention.
|
||||
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||
"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
|
@ -82,7 +83,8 @@
|
|||
[
|
||||
{
|
||||
"name": "ljspeech",
|
||||
"path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
|
||||
// "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
|
||||
"path": "/home/erogol/Data/LJSpeech-1.1",
|
||||
"meta_file_train": "metadata_train.csv",
|
||||
"meta_file_val": "metadata_val.csv"
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# coding: utf-8
|
||||
import torch
|
||||
from torch import nn
|
||||
from .common_layers import Prenet, Attention
|
||||
from .common_layers import Prenet, Attention, Linear
|
||||
|
||||
|
||||
class BatchNormConv1d(nn.Module):
|
||||
|
@ -125,13 +125,12 @@ class CBHG(nn.Module):
|
|||
# list of conv1d bank with filter size k=1...K
|
||||
# TODO: try dilational layers instead
|
||||
self.conv1d_banks = nn.ModuleList([
|
||||
BatchNormConv1d(
|
||||
in_features,
|
||||
conv_bank_features,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=[(k - 1) // 2, k // 2],
|
||||
activation=self.relu) for k in range(1, K + 1)
|
||||
BatchNormConv1d(in_features,
|
||||
conv_bank_features,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=[(k - 1) // 2, k // 2],
|
||||
activation=self.relu) for k in range(1, K + 1)
|
||||
])
|
||||
# max pooling of conv bank, with padding
|
||||
# TODO: try average pooling OR larger kernel size
|
||||
|
@ -142,39 +141,33 @@ class CBHG(nn.Module):
|
|||
layer_set = []
|
||||
for (in_size, out_size, ac) in zip(out_features, conv_projections,
|
||||
activations):
|
||||
layer = BatchNormConv1d(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=[1, 1],
|
||||
activation=ac)
|
||||
layer = BatchNormConv1d(in_size,
|
||||
out_size,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=[1, 1],
|
||||
activation=ac)
|
||||
layer_set.append(layer)
|
||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||
# setup Highway layers
|
||||
if self.highway_features != conv_projections[-1]:
|
||||
self.pre_highway = nn.Linear(
|
||||
conv_projections[-1], highway_features, bias=False)
|
||||
self.pre_highway = nn.Linear(conv_projections[-1],
|
||||
highway_features,
|
||||
bias=False)
|
||||
self.highways = nn.ModuleList([
|
||||
Highway(highway_features, highway_features)
|
||||
for _ in range(num_highways)
|
||||
])
|
||||
# bi-directional GPU layer
|
||||
self.gru = nn.GRU(
|
||||
gru_features,
|
||||
gru_features,
|
||||
1,
|
||||
batch_first=True,
|
||||
bidirectional=True)
|
||||
self.gru = nn.GRU(gru_features,
|
||||
gru_features,
|
||||
1,
|
||||
batch_first=True,
|
||||
bidirectional=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
# (B, T_in, in_features)
|
||||
x = inputs
|
||||
# Needed to perform conv1d on time-axis
|
||||
# (B, in_features, T_in)
|
||||
if x.size(-1) == self.in_features:
|
||||
x = x.transpose(1, 2)
|
||||
# T = x.size(-1)
|
||||
x = inputs
|
||||
# (B, hid_features*K, T_in)
|
||||
# Concat conv1d bank outputs
|
||||
outs = []
|
||||
|
@ -185,10 +178,8 @@ class CBHG(nn.Module):
|
|||
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
|
||||
for conv1d in self.conv1d_projections:
|
||||
x = conv1d(x)
|
||||
# (B, T_in, hid_feature)
|
||||
x = x.transpose(1, 2)
|
||||
# Back to the original shape
|
||||
x += inputs
|
||||
x = x.transpose(1, 2)
|
||||
if self.highway_features != self.conv_projections[-1]:
|
||||
x = self.pre_highway(x)
|
||||
# Residual connection
|
||||
|
@ -236,8 +227,10 @@ class Encoder(nn.Module):
|
|||
- inputs: batch x time x in_features
|
||||
- outputs: batch x time x 128*2
|
||||
"""
|
||||
inputs = self.prenet(inputs)
|
||||
return self.cbhg(inputs)
|
||||
# B x T x prenet_dim
|
||||
outputs = self.prenet(inputs)
|
||||
outputs = self.cbhg(outputs.transpose(1, 2))
|
||||
return outputs
|
||||
|
||||
|
||||
class PostCBHG(nn.Module):
|
||||
|
@ -314,7 +307,12 @@ class Decoder(nn.Module):
|
|||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
|
||||
# learn init values instead of zero init.
|
||||
self.stopnet = StopNet(256 + memory_dim * self.r_init)
|
||||
self.stopnet = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
Linear(256 + memory_dim * self.r_init,
|
||||
1,
|
||||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
||||
def set_r(self, new_r):
|
||||
self.r = new_r
|
||||
|
@ -356,8 +354,9 @@ class Decoder(nn.Module):
|
|||
def _parse_outputs(self, outputs, attentions, stop_tokens):
|
||||
# Back to batch first
|
||||
attentions = torch.stack(attentions).transpose(0, 1)
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
|
||||
outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
|
||||
return outputs, attentions, stop_tokens
|
||||
|
||||
def decode(self, inputs, mask=None):
|
||||
|
@ -438,9 +437,8 @@ class Decoder(nn.Module):
|
|||
output, stop_token, attention = self.decode(inputs, mask)
|
||||
outputs += [output]
|
||||
attentions += [attention]
|
||||
stop_tokens += [stop_token]
|
||||
stop_tokens += [stop_token.squeeze(1)]
|
||||
t += 1
|
||||
|
||||
return self._parse_outputs(outputs, attentions, stop_tokens)
|
||||
|
||||
def inference(self, inputs, speaker_embeddings=None):
|
||||
|
@ -481,20 +479,20 @@ class Decoder(nn.Module):
|
|||
return self._parse_outputs(outputs, attentions, stop_tokens)
|
||||
|
||||
|
||||
class StopNet(nn.Module):
|
||||
r"""
|
||||
Args:
|
||||
in_features (int): feature dimension of input.
|
||||
"""
|
||||
# class StopNet(nn.Module):
|
||||
# r"""
|
||||
# Args:
|
||||
# in_features (int): feature dimension of input.
|
||||
# """
|
||||
|
||||
def __init__(self, in_features):
|
||||
super(StopNet, self).__init__()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.linear = nn.Linear(in_features, 1)
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
|
||||
# def __init__(self, in_features):
|
||||
# super(StopNet, self).__init__()
|
||||
# self.dropout = nn.Dropout(0.1)
|
||||
# self.linear = nn.Linear(in_features, 1)
|
||||
# torch.nn.init.xavier_uniform_(
|
||||
# self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.dropout(inputs)
|
||||
outputs = self.linear(outputs)
|
||||
return outputs
|
||||
# def forward(self, inputs):
|
||||
# outputs = self.dropout(inputs)
|
||||
# outputs = self.linear(outputs)
|
||||
# return outputs
|
||||
|
|
|
@ -98,11 +98,12 @@ class Encoder(nn.Module):
|
|||
class Decoder(nn.Module):
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
|
||||
def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
|
||||
prenet_type, prenet_dropout, forward_attn, trans_agent,
|
||||
forward_attn_mask, location_attn, separate_stopnet):
|
||||
forward_attn_mask, location_attn, separate_stopnet,
|
||||
speaker_embedding_dim):
|
||||
super(Decoder, self).__init__()
|
||||
self.mel_channels = inputs_dim
|
||||
self.memory_dim = memory_dim
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.encoder_embedding_dim = in_features
|
||||
|
@ -114,11 +115,15 @@ class Decoder(nn.Module):
|
|||
self.gate_threshold = 0.5
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
self.prenet = Prenet(self.mel_channels,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
prenet_dim = self.memory_dim + speaker_embedding_dim
|
||||
self.prenet = Prenet(
|
||||
prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
||||
self.query_dim)
|
||||
|
@ -139,11 +144,11 @@ class Decoder(nn.Module):
|
|||
self.decoder_rnn_dim, 1)
|
||||
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
|
||||
self.mel_channels * self.r_init)
|
||||
self.memory_dim * self.r_init)
|
||||
|
||||
self.stopnet = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
|
||||
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
|
||||
1,
|
||||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
@ -155,7 +160,7 @@ class Decoder(nn.Module):
|
|||
def get_go_frame(self, inputs):
|
||||
B = inputs.size(0)
|
||||
memory = torch.zeros(1, device=inputs.device).repeat(B,
|
||||
self.mel_channels * self.r)
|
||||
self.memory_dim * self.r)
|
||||
return memory
|
||||
|
||||
def _init_states(self, inputs, mask, keep_states=False):
|
||||
|
@ -185,16 +190,14 @@ class Decoder(nn.Module):
|
|||
def _parse_outputs(self, outputs, stop_tokens, alignments):
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
stop_tokens = stop_tokens.contiguous()
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
|
||||
outputs = outputs.transpose(1, 2)
|
||||
outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
|
||||
return outputs, stop_tokens, alignments
|
||||
|
||||
def _update_memory(self, memory):
|
||||
if len(memory.shape) == 2:
|
||||
return memory[:, self.mel_channels * (self.r - 1):]
|
||||
return memory[:, :, self.mel_channels * (self.r - 1):]
|
||||
return memory[:, self.memory_dim * (self.r - 1):]
|
||||
return memory[:, :, self.memory_dim * (self.r - 1):]
|
||||
|
||||
def decode(self, memory):
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
|
@ -228,10 +231,10 @@ class Decoder(nn.Module):
|
|||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
decoder_output = decoder_output[:, :self.r * self.mel_channels]
|
||||
decoder_output = decoder_output[:, :self.r * self.memory_dim]
|
||||
return decoder_output, stop_token, self.attention.attention_weights
|
||||
|
||||
def forward(self, inputs, memories, mask):
|
||||
def forward(self, inputs, memories, mask, speaker_embeddings=None):
|
||||
memory = self.get_go_frame(inputs).unsqueeze(0)
|
||||
memories = self._reshape_memory(memories)
|
||||
memories = torch.cat((memory, memories), dim=0)
|
||||
|
@ -243,6 +246,8 @@ class Decoder(nn.Module):
|
|||
outputs, stop_tokens, alignments = [], [], []
|
||||
while len(outputs) < memories.size(0) - 1:
|
||||
memory = memories[len(outputs)]
|
||||
if speaker_embeddings is not None:
|
||||
memory = torch.cat([memory, speaker_embeddings], dim=-1)
|
||||
mel_output, stop_token, attention_weights = self.decode(memory)
|
||||
outputs += [mel_output.squeeze(1)]
|
||||
stop_tokens += [stop_token.squeeze(1)]
|
||||
|
@ -253,7 +258,7 @@ class Decoder(nn.Module):
|
|||
|
||||
return outputs, stop_tokens, alignments
|
||||
|
||||
def inference(self, inputs):
|
||||
def inference(self, inputs, speaker_embeddings=None):
|
||||
memory = self.get_go_frame(inputs)
|
||||
memory = self._update_memory(memory)
|
||||
|
||||
|
@ -266,6 +271,8 @@ class Decoder(nn.Module):
|
|||
stop_flags = [True, False, False]
|
||||
while True:
|
||||
memory = self.prenet(memory)
|
||||
if speaker_embeddings is not None:
|
||||
memory = torch.cat([memory, speaker_embeddings], dim=-1)
|
||||
mel_output, stop_token, alignment = self.decode(memory)
|
||||
stop_token = torch.sigmoid(stop_token.data)
|
||||
outputs += [mel_output.squeeze(1)]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# coding: utf-8
|
||||
import torch
|
||||
import copy
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
@ -11,8 +12,8 @@ class Tacotron(nn.Module):
|
|||
num_chars,
|
||||
num_speakers,
|
||||
r=5,
|
||||
linear_dim=1025,
|
||||
mel_dim=80,
|
||||
postnet_output_dim=1025,
|
||||
decoder_output_dim=80,
|
||||
memory_size=5,
|
||||
attn_win=False,
|
||||
gst=False,
|
||||
|
@ -23,28 +24,33 @@ class Tacotron(nn.Module):
|
|||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
separate_stopnet=True):
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron, self).__init__()
|
||||
self.r = r
|
||||
self.mel_dim = mel_dim
|
||||
self.linear_dim = linear_dim
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.embedding = nn.Embedding(num_chars, 256)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 256
|
||||
encoder_dim = 512 if num_speakers > 1 else 256
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(num_chars, 256)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
# boilerplate model
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
|
||||
self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, separate_stopnet,
|
||||
proj_speaker_dim)
|
||||
self.postnet = PostCBHG(mel_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
|
||||
linear_dim)
|
||||
postnet_output_dim)
|
||||
# speaker embedding layers
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 256)
|
||||
|
@ -82,27 +88,48 @@ class Tacotron(nn.Module):
|
|||
return inputs
|
||||
|
||||
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
||||
"""
|
||||
Shapes:
|
||||
- characters: B x T_in
|
||||
- text_lengths: B
|
||||
- mel_specs: B x T_out x D
|
||||
- speaker_ids: B x 1
|
||||
"""
|
||||
self._init_states()
|
||||
B = characters.size(0)
|
||||
mask = sequence_mask(text_lengths).to(characters.device)
|
||||
# B x T_in x embed_dim
|
||||
inputs = self.embedding(characters)
|
||||
self._init_states()
|
||||
# B x speaker_embed_dim
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
# B x T_in x embed_dim + speaker_embed_dim
|
||||
inputs = self._concat_speaker_embedding(inputs,
|
||||
self.speaker_embeddings)
|
||||
# B x T_in x encoder_dim
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, self.speaker_embeddings)
|
||||
mel_outputs, alignments, stop_tokens = self.decoder(
|
||||
# decoder_outputs: B x decoder_dim x T_out
|
||||
# alignments: B x T_in x encoder_dim
|
||||
# stop_tokens: B x T_in
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, mask,
|
||||
self.speaker_embeddings_projected)
|
||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||
linear_outputs = self.postnet(mel_outputs)
|
||||
linear_outputs = self.last_linear(linear_outputs)
|
||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||
# B x T_out x decoder_dim
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
# B x T_out x posnet_dim
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
# B x T_out x decoder_dim
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def inference(self, characters, speaker_ids=None, style_mel=None):
|
||||
B = characters.size(0)
|
||||
|
@ -118,12 +145,19 @@ class Tacotron(nn.Module):
|
|||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, self.speaker_embeddings)
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs, self.speaker_embeddings_projected)
|
||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||
linear_outputs = self.postnet(mel_outputs)
|
||||
linear_outputs = self.last_linear(linear_outputs)
|
||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||
decoder_outputs = decoder_outputs.view(B, -1, self.decoder_output_dim)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _compute_speaker_embedding(self, speaker_ids):
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import copy
|
||||
import torch
|
||||
from math import sqrt
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
|
||||
|
@ -10,6 +12,8 @@ class Tacotron2(nn.Module):
|
|||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
|
@ -18,10 +22,16 @@ class Tacotron2(nn.Module):
|
|||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
separate_stopnet=True):
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.n_mel_channels = 80
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.n_frames_per_step = r
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 + 256 if num_speakers > 1 else 512
|
||||
encoder_dim = 512 + 256 if num_speakers > 1 else 512
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(num_chars, 512)
|
||||
std = sqrt(2.0 / (num_chars + 512))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
|
@ -29,12 +39,18 @@ class Tacotron2(nn.Module):
|
|||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 512)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.encoder = Encoder(512)
|
||||
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, separate_stopnet)
|
||||
self.postnet = Postnet(self.n_mel_channels)
|
||||
location_attn, separate_stopnet, proj_speaker_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = Postnet(self.decoder_output_dim)
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
@staticmethod
|
||||
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
|
||||
|
@ -43,19 +59,23 @@ class Tacotron2(nn.Module):
|
|||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
|
||||
self._init_states()
|
||||
# compute mask for padding
|
||||
mask = sequence_mask(text_lengths).to(text.device)
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
mel_outputs, stop_tokens, alignments = self.decoder(
|
||||
decoder_outputs, stop_tokens, alignments = self.decoder(
|
||||
encoder_outputs, mel_specs, mask)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
||||
mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def inference(self, text, speaker_ids=None):
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
|
@ -86,6 +106,13 @@ class Tacotron2(nn.Module):
|
|||
mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
|
|
28
train.py
28
train.py
|
@ -88,6 +88,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
'avg_loader_time': 0,
|
||||
'avg_alignment_score': 0
|
||||
}
|
||||
if c.bidirectional_decoder:
|
||||
train_values['avg_decoder_b_loss'] = 0 # decoder backward loss
|
||||
train_values['avg_decoder_c_loss'] = 0 # decoder consistency loss
|
||||
keep_avg = KeepAverage()
|
||||
keep_avg.add_values(train_values)
|
||||
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||
|
@ -150,8 +153,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||
|
||||
# forward pass model
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
if c.bidirectional_decoder:
|
||||
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens,
|
||||
|
@ -174,6 +181,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
if not c.separate_stopnet and c.stopnet:
|
||||
loss += stop_loss
|
||||
|
||||
# backward decoder
|
||||
if c.bidirectional_decoder:
|
||||
if c.loss_masking:
|
||||
decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
|
||||
else:
|
||||
decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
|
||||
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
|
||||
loss = decoder_backward_loss + decoder_c_loss
|
||||
keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
|
||||
|
||||
loss.backward()
|
||||
optimizer, current_lr = adam_weight_decay(optimizer)
|
||||
grad_norm, _ = check_update(model, c.grad_clip)
|
||||
|
@ -445,7 +462,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
"ground_truth": plot_spectrogram(gt_spec, ap),
|
||||
"alignment": plot_alignment(align_img)
|
||||
}
|
||||
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
|
@ -461,7 +477,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
"loss_decoder": keep_avg['avg_decoder_loss'],
|
||||
"stop_loss": keep_avg['avg_stop_loss']
|
||||
}
|
||||
|
||||
if c.bidirectional_decoder:
|
||||
epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_backward']
|
||||
epoch_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
|
||||
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
||||
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||
|
||||
|
||||
if args.rank == 0 and epoch > c.test_delay_epochs:
|
||||
# test sentences
|
||||
|
|
|
@ -283,8 +283,8 @@ def setup_model(num_chars, num_speakers, c):
|
|||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
linear_dim=1025,
|
||||
mel_dim=80,
|
||||
postnet_output_dim=c.audio['num_freq'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
gst=c.use_gst,
|
||||
memory_size=c.memory_size,
|
||||
attn_win=c.windowing,
|
||||
|
@ -295,11 +295,14 @@ def setup_model(num_chars, num_speakers, c):
|
|||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
separate_stopnet=c.separate_stopnet)
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
|
@ -308,7 +311,8 @@ def setup_model(num_chars, num_speakers, c):
|
|||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
separate_stopnet=c.separate_stopnet)
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue