comment model outputs for tacotron, align output shapes of tacotron and tacotron2, merge bidirectional decoder

This commit is contained in:
Eren Golge 2019-10-28 14:51:19 +01:00
parent 1c2d26ccb9
commit e83a4b07d2
7 changed files with 207 additions and 113 deletions

View File

@@ -46,6 +46,7 @@
     "forward_attn_mask": false,
     "transition_agent": false,  // enable/disable transition agent of forward attention.
     "location_attn": true,  // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
+    "bidirectional_decoder": true,  // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
     "loss_masking": true,  // enable / disable loss masking against the sequence padding.
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
     "stopnet": true,  // Train stopnet predicting the end of synthesis.
@@ -82,7 +83,8 @@
 [
   {
     "name": "ljspeech",
-    "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+    // "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
+    "path": "/home/erogol/Data/LJSpeech-1.1",
     "meta_file_train": "metadata_train.csv",
     "meta_file_val": "metadata_val.csv"
   }

View File

@@ -1,7 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, Attention
+from .common_layers import Prenet, Attention, Linear


 class BatchNormConv1d(nn.Module):
@@ -125,13 +125,12 @@ class CBHG(nn.Module):
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
         self.conv1d_banks = nn.ModuleList([
-            BatchNormConv1d(
-                in_features,
-                conv_bank_features,
-                kernel_size=k,
-                stride=1,
-                padding=[(k - 1) // 2, k // 2],
-                activation=self.relu) for k in range(1, K + 1)
+            BatchNormConv1d(in_features,
+                            conv_bank_features,
+                            kernel_size=k,
+                            stride=1,
+                            padding=[(k - 1) // 2, k // 2],
+                            activation=self.relu) for k in range(1, K + 1)
         ])
         # max pooling of conv bank, with padding
         # TODO: try average pooling OR larger kernel size
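The asymmetric padding pair [(k - 1) // 2, k // 2] is what lets every bank branch, odd or even kernel size, return the same time length it was given. A standalone check (illustrative; it assumes BatchNormConv1d applies the pair as left/right constant padding):

import torch
from torch import nn

T = 13
x = torch.randn(2, 8, T)  # (B, C, T)
for k in range(1, 9):
    pad = nn.ConstantPad1d(((k - 1) // 2, k // 2), 0.0)
    conv = nn.Conv1d(8, 8, kernel_size=k, stride=1)
    # total padding is k - 1, exactly the frames a size-k kernel consumes
    assert conv(pad(x)).shape[-1] == T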
@@ -142,39 +141,33 @@ class CBHG(nn.Module):
         layer_set = []
         for (in_size, out_size, ac) in zip(out_features, conv_projections,
                                            activations):
-            layer = BatchNormConv1d(
-                in_size,
-                out_size,
-                kernel_size=3,
-                stride=1,
-                padding=[1, 1],
-                activation=ac)
+            layer = BatchNormConv1d(in_size,
+                                    out_size,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=[1, 1],
+                                    activation=ac)
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
         if self.highway_features != conv_projections[-1]:
-            self.pre_highway = nn.Linear(
-                conv_projections[-1], highway_features, bias=False)
+            self.pre_highway = nn.Linear(conv_projections[-1],
+                                         highway_features,
+                                         bias=False)
         self.highways = nn.ModuleList([
             Highway(highway_features, highway_features)
             for _ in range(num_highways)
         ])
         # bi-directional GPU layer
-        self.gru = nn.GRU(
-            gru_features,
-            gru_features,
-            1,
-            batch_first=True,
-            bidirectional=True)
+        self.gru = nn.GRU(gru_features,
+                          gru_features,
+                          1,
+                          batch_first=True,
+                          bidirectional=True)

     def forward(self, inputs):
-        # (B, T_in, in_features)
-        x = inputs
-        # Needed to perform conv1d on time-axis
-        # (B, in_features, T_in)
-        if x.size(-1) == self.in_features:
-            x = x.transpose(1, 2)
-        # T = x.size(-1)
+        # (B, in_features, T_in)
+        x = inputs
         # (B, hid_features*K, T_in)
         # Concat conv1d bank outputs
         outs = []
@@ -185,10 +178,8 @@ class CBHG(nn.Module):
         assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
         for conv1d in self.conv1d_projections:
             x = conv1d(x)
-        # (B, T_in, hid_feature)
-        x = x.transpose(1, 2)
-        # Back to the original shape
         x += inputs
+        x = x.transpose(1, 2)
         if self.highway_features != self.conv_projections[-1]:
             x = self.pre_highway(x)
         # Residual connection
@@ -236,8 +227,10 @@ class Encoder(nn.Module):
             - inputs: batch x time x in_features
             - outputs: batch x time x 128*2
         """
-        inputs = self.prenet(inputs)
-        return self.cbhg(inputs)
+        # B x T x prenet_dim
+        outputs = self.prenet(inputs)
+        outputs = self.cbhg(outputs.transpose(1, 2))
+        return outputs


 class PostCBHG(nn.Module):
@@ -314,7 +307,12 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
         # learn init values instead of zero init.
-        self.stopnet = StopNet(256 + memory_dim * self.r_init)
+        self.stopnet = nn.Sequential(
+            nn.Dropout(0.1),
+            Linear(256 + memory_dim * self.r_init,
+                   1,
+                   bias=True,
+                   init_gain='sigmoid'))

     def set_r(self, new_r):
         self.r = new_r
@@ -356,8 +354,9 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
         attentions = torch.stack(attentions).transpose(0, 1)
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
+        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
         return outputs, attentions, stop_tokens

     def decode(self, inputs, mask=None):
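This is the shape-alignment half of the commit title: the Tacotron decoder now returns outputs as B x memory_dim x T_out, matching the Tacotron2 decoder, and the models transpose back to B x T_out x D at the end of forward. Shape bookkeeping only, with illustrative sizes:

import torch

B, memory_dim, r, steps = 2, 80, 5, 7
# one B x (memory_dim * r) tensor per decoder step
outputs = [torch.randn(B, memory_dim * r) for _ in range(steps)]
x = torch.stack(outputs).transpose(0, 1).contiguous()  # B x steps x (memory_dim * r)
x = x.view(B, memory_dim, -1)                          # B x memory_dim x T_out
assert x.shape == (B, memory_dim, steps * r)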
@@ -438,9 +437,8 @@ class Decoder(nn.Module):
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
-            stop_tokens += [stop_token]
+            stop_tokens += [stop_token.squeeze(1)]
             t += 1
         return self._parse_outputs(outputs, attentions, stop_tokens)

     def inference(self, inputs, speaker_embeddings=None):
@@ -481,20 +479,20 @@ class Decoder(nn.Module):
         return self._parse_outputs(outputs, attentions, stop_tokens)


-class StopNet(nn.Module):
-    r"""
-    Args:
-        in_features (int): feature dimension of input.
-    """
-
-    def __init__(self, in_features):
-        super(StopNet, self).__init__()
-        self.dropout = nn.Dropout(0.1)
-        self.linear = nn.Linear(in_features, 1)
-        torch.nn.init.xavier_uniform_(
-            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
-
-    def forward(self, inputs):
-        outputs = self.dropout(inputs)
-        outputs = self.linear(outputs)
-        return outputs
+# class StopNet(nn.Module):
+#     r"""
+#     Args:
+#         in_features (int): feature dimension of input.
+#     """
+
+#     def __init__(self, in_features):
+#         super(StopNet, self).__init__()
+#         self.dropout = nn.Dropout(0.1)
+#         self.linear = nn.Linear(in_features, 1)
+#         torch.nn.init.xavier_uniform_(
+#             self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
+
+#     def forward(self, inputs):
+#         outputs = self.dropout(inputs)
+#         outputs = self.linear(outputs)
+#         return outputs
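The standalone StopNet goes away in favor of the same inline stopnet Tacotron2 already uses: Dropout followed by the custom Linear imported from common_layers, initialized with a sigmoid gain where the old module used a linear one. common_layers.Linear itself is not part of this diff; a sketch of what such a wrapper plausibly looks like (an assumption, for orientation only):

import torch
from torch import nn

class Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias=bias)
        # Xavier init with the gain matched to the following nonlinearity,
        # e.g. init_gain='sigmoid' for stop-token logits
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)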

View File

@@ -98,11 +98,12 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
+    def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, separate_stopnet):
+                 forward_attn_mask, location_attn, separate_stopnet,
+                 speaker_embedding_dim):
         super(Decoder, self).__init__()
-        self.mel_channels = inputs_dim
+        self.memory_dim = memory_dim
         self.r_init = r
         self.r = r
         self.encoder_embedding_dim = in_features
@@ -114,11 +115,15 @@ class Decoder(nn.Module):
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-        self.prenet = Prenet(self.mel_channels,
-                             prenet_type,
-                             prenet_dropout,
-                             [self.prenet_dim, self.prenet_dim],
-                             bias=False)
+
+        # memory -> |Prenet| -> processed_memory
+        prenet_dim = self.memory_dim + speaker_embedding_dim
+        self.prenet = Prenet(
+            prenet_dim,
+            prenet_type,
+            prenet_dropout,
+            out_features=[self.prenet_dim, self.prenet_dim],
+            bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.query_dim)
@@ -139,11 +144,11 @@ class Decoder(nn.Module):
                                      self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.mel_channels * self.r_init)
+                                        self.memory_dim * self.r_init)

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
+            Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
                    1,
                    bias=True,
                    init_gain='sigmoid'))
@@ -155,7 +160,7 @@ class Decoder(nn.Module):
     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = torch.zeros(1, device=inputs.device).repeat(B,
-                             self.mel_channels * self.r)
+                             self.memory_dim * self.r)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
@@ -185,16 +190,14 @@ class Decoder(nn.Module):
     def _parse_outputs(self, outputs, stop_tokens, alignments):
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
-        outputs = outputs.transpose(1, 2)
+        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
         return outputs, stop_tokens, alignments

     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.mel_channels * (self.r - 1):]
-        return memory[:, :, self.mel_channels * (self.r - 1):]
+            return memory[:, self.memory_dim * (self.r - 1):]
+        return memory[:, :, self.memory_dim * (self.r - 1):]

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
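For context on the slice in _update_memory: with reduction factor r the decoder emits r frames per step, and only the last frame of that group is fed back as the next step's input. Assuming the frames are laid out frame-major inside the flat vector (an inference from the indexing, not stated in the diff):

import torch

B, memory_dim, r = 2, 80, 3
memory = torch.randn(B, memory_dim * r)        # r frames from the previous step
last_frame = memory[:, memory_dim * (r - 1):]  # keep only the final frame
assert last_frame.shape == (B, memory_dim)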
@@ -228,10 +231,10 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        decoder_output = decoder_output[:, :self.r * self.mel_channels]
+        decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, stop_token, self.attention.attention_weights

-    def forward(self, inputs, memories, mask):
+    def forward(self, inputs, memories, mask, speaker_embeddings=None):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
@@ -243,6 +246,8 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
+            if speaker_embeddings is not None:
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, stop_token, attention_weights = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
@@ -253,7 +258,7 @@ class Decoder(nn.Module):
         return outputs, stop_tokens, alignments

-    def inference(self, inputs):
+    def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
         memory = self._update_memory(memory)
@@ -266,6 +271,8 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(memory)
+            if speaker_embeddings is not None:
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, stop_token, alignment = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]

View File

@@ -1,5 +1,6 @@
 # coding: utf-8
 import torch
+import copy
 from torch import nn
 from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
 from TTS.utils.generic_utils import sequence_mask
@@ -11,8 +12,8 @@ class Tacotron(nn.Module):
                  num_chars,
                  num_speakers,
                  r=5,
-                 linear_dim=1025,
-                 mel_dim=80,
+                 postnet_output_dim=1025,
+                 decoder_output_dim=80,
                  memory_size=5,
                  attn_win=False,
                  gst=False,
@@ -23,28 +24,33 @@ class Tacotron(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron, self).__init__()
         self.r = r
-        self.mel_dim = mel_dim
-        self.linear_dim = linear_dim
+        self.decoder_output_dim = decoder_output_dim
+        self.postnet_output_dim = postnet_output_dim
         self.gst = gst
         self.num_speakers = num_speakers
-        self.embedding = nn.Embedding(num_chars, 256)
-        self.embedding.weight.data.normal_(0, 0.3)
+        self.bidirectional_decoder = bidirectional_decoder
         decoder_dim = 512 if num_speakers > 1 else 256
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
+        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
-        self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
+        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, separate_stopnet,
                                proj_speaker_dim)
-        self.postnet = PostCBHG(mel_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = PostCBHG(decoder_output_dim)
         self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
-                                     linear_dim)
+                                     postnet_output_dim)
         # speaker embedding layers
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
@@ -82,27 +88,48 @@ class Tacotron(nn.Module):
         return inputs

     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
+        """
+        Shapes:
+            - characters: B x T_in
+            - text_lengths: B
+            - mel_specs: B x T_out x D
+            - speaker_ids: B x 1
+        """
+        self._init_states()
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
+        # B x T_in x embed_dim
         inputs = self.embedding(characters)
-        self._init_states()
+        # B x speaker_embed_dim
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
+            # B x T_in x embed_dim + speaker_embed_dim
             inputs = self._concat_speaker_embedding(inputs,
                                                     self.speaker_embeddings)
+        # B x T_in x encoder_dim
         encoder_outputs = self.encoder(inputs)
         if self.gst:
+            # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder(
+        # decoder_outputs: B x decoder_dim x T_out
+        # alignments: B x T_in x encoder_dim
+        # stop_tokens: B x T_in
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask,
             self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        # B x T_out x decoder_dim
+        postnet_outputs = self.postnet(decoder_outputs)
+        # B x T_out x posnet_dim
+        postnet_outputs = self.last_linear(postnet_outputs)
+        # B x T_out x decoder_dim
+        decoder_outputs = decoder_outputs.transpose(1, 2)
+        if self.bidirectional_decoder:
+            decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
+            return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, characters, speaker_ids=None, style_mel=None):
         B = characters.size(0)
@@ -118,12 +145,19 @@ class Tacotron(nn.Module):
         if self.num_speakers > 1:
             encoder_outputs = self._concat_speaker_embedding(
                 encoder_outputs, self.speaker_embeddings)
-        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs, self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
+        decoder_outputs = decoder_outputs.view(B, -1, self.decoder_output_dim)
+        postnet_outputs = self.postnet(decoder_outputs)
+        postnet_outputs = self.last_linear(postnet_outputs)
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens
+
+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b

     def _compute_speaker_embedding(self, speaker_ids):
         speaker_embeddings = self.speaker_embedding(speaker_ids)
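_backward_inference runs the cloned decoder on the mel targets reversed along the time axis; torch.flip on dim 1 is the whole trick. A quick illustrative check (not repo code):

import torch

mel_specs = torch.arange(6.).view(1, 3, 2)    # B=1, T=3, D=2
flipped = torch.flip(mel_specs, dims=(1,))    # reverse the frame order
assert torch.equal(flipped[0, 0], mel_specs[0, -1])  # last frame now first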

View File

@@ -1,3 +1,5 @@
+import copy
+import torch
 from math import sqrt
 from torch import nn
 from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -10,6 +12,8 @@ class Tacotron2(nn.Module):
                  num_chars,
                  num_speakers,
                  r,
+                 postnet_output_dim=80,
+                 decoder_output_dim=80,
                  attn_win=False,
                  attn_norm="softmax",
                  prenet_type="original",
@@ -18,10 +22,16 @@ class Tacotron2(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
-                 separate_stopnet=True):
+                 separate_stopnet=True,
+                 bidirectional_decoder=False):
         super(Tacotron2, self).__init__()
-        self.n_mel_channels = 80
+        self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
+        self.bidirectional_decoder = bidirectional_decoder
+        decoder_dim = 512 + 256 if num_speakers > 1 else 512
+        encoder_dim = 512 + 256 if num_speakers > 1 else 512
+        proj_speaker_dim = 80 if num_speakers > 1 else 0
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
@@ -29,12 +39,18 @@ class Tacotron2(nn.Module):
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
+        self.encoder = Encoder(encoder_dim)
+        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet)
-        self.postnet = Postnet(self.n_mel_channels)
+                               location_attn, separate_stopnet, proj_speaker_dim)
+        if self.bidirectional_decoder:
+            self.decoder_backward = copy.deepcopy(self.decoder)
+        self.postnet = Postnet(self.decoder_output_dim)
+
+    def _init_states(self):
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None

     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@@ -43,19 +59,23 @@ class Tacotron2(nn.Module):
         return mel_outputs, mel_outputs_postnet, alignments

     def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
+        self._init_states()
         # compute mask for padding
         mask = sequence_mask(text_lengths).to(text.device)
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder(
+        decoder_outputs, stop_tokens, alignments = self.decoder(
             encoder_outputs, mel_specs, mask)
-        mel_outputs_postnet = self.postnet(mel_outputs)
-        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
-        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
-            mel_outputs, mel_outputs_postnet, alignments)
-        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+        postnet_outputs = self.postnet(decoder_outputs)
+        postnet_outputs = decoder_outputs + postnet_outputs
+        decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
+            decoder_outputs, postnet_outputs, alignments)
+        if self.bidirectional_decoder:
+            decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
+            return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
+        return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, text, speaker_ids=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
@@ -86,6 +106,13 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

+    def _backward_inference(self, mel_specs, encoder_outputs, mask):
+        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
+            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
+            self.speaker_embeddings_projected)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        return decoder_outputs_b, alignments_b
+
     def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
         if hasattr(self, "speaker_embedding") and speaker_ids is None:
             raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")

View File

@@ -88,6 +88,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         'avg_loader_time': 0,
         'avg_alignment_score': 0
     }
+    if c.bidirectional_decoder:
+        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
+        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
     keep_avg = KeepAverage()
     keep_avg.add_values(train_values)
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
@@ -150,8 +153,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             speaker_ids = speaker_ids.cuda(non_blocking=True)

         # forward pass model
-        decoder_output, postnet_output, alignments, stop_tokens = model(
-            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+        if c.bidirectional_decoder:
+            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+        else:
+            decoder_output, postnet_output, alignments, stop_tokens = model(
+                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

         # loss computation
         stop_loss = criterion_st(stop_tokens,
@@ -174,6 +181,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if not c.separate_stopnet and c.stopnet:
             loss += stop_loss

+        # backward decoder
+        if c.bidirectional_decoder:
+            if c.loss_masking:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+            else:
+                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+            decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+            loss += decoder_backward_loss + decoder_c_loss
+            keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
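Taken together this is the merged double-decoder objective (https://arxiv.org/abs/1907.09006): the backward decoder's output is flipped back into forward time order, scored against the targets, and tied to the forward decoder by an L1 consistency term. A minimal sketch of the math, assuming plain unmasked L1 in place of the repo's criterion:

import torch
import torch.nn.functional as F

B, T, D = 2, 12, 80
decoder_output = torch.randn(B, T, D, requires_grad=True)           # forward decoder
decoder_backward_output = torch.randn(B, T, D, requires_grad=True)  # backward decoder (reversed time)
mel_input = torch.randn(B, T, D)

back_fwd_order = torch.flip(decoder_backward_output, dims=(1,))
decoder_backward_loss = F.l1_loss(back_fwd_order, mel_input)   # reconstruction term
decoder_c_loss = F.l1_loss(back_fwd_order, decoder_output)     # consistency term
loss = decoder_backward_loss + decoder_c_loss  # added onto the forward losses
loss.backward()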
@@ -445,7 +462,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 "ground_truth": plot_spectrogram(gt_spec, ap),
                 "alignment": plot_alignment(align_img)
             }
-            tb_logger.tb_eval_figures(global_step, eval_figures)

             # Sample audio
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -461,7 +477,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             "loss_decoder": keep_avg['avg_decoder_loss'],
             "stop_loss": keep_avg['avg_stop_loss']
         }
+        if c.bidirectional_decoder:
+            epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_backward']
+            epoch_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
         tb_logger.tb_eval_stats(global_step, epoch_stats)
+        tb_logger.tb_eval_figures(global_step, eval_figures)

     if args.rank == 0 and epoch > c.test_delay_epochs:
         # test sentences

View File

@@ -283,8 +283,8 @@ def setup_model(num_chars, num_speakers, c):
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
-                        linear_dim=1025,
-                        mel_dim=80,
+                        postnet_output_dim=c.audio['num_freq'],
+                        decoder_output_dim=c.audio['num_mels'],
                         gst=c.use_gst,
                         memory_size=c.memory_size,
                         attn_win=c.windowing,
@@ -295,11 +295,14 @@ def setup_model(num_chars, num_speakers, c):
                         trans_agent=c.transition_agent,
                         forward_attn_mask=c.forward_attn_mask,
                         location_attn=c.location_attn,
-                        separate_stopnet=c.separate_stopnet)
+                        separate_stopnet=c.separate_stopnet,
+                        bidirectional_decoder=c.bidirectional_decoder)
     elif c.model.lower() == "tacotron2":
         model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
+                        postnet_output_dim=c.audio['num_mels'],
+                        decoder_output_dim=c.audio['num_mels'],
                         attn_win=c.windowing,
                         attn_norm=c.attention_norm,
                         prenet_type=c.prenet_type,
@@ -308,7 +311,8 @@ def setup_model(num_chars, num_speakers, c):
                         trans_agent=c.transition_agent,
                         forward_attn_mask=c.forward_attn_mask,
                         location_attn=c.location_attn,
-                        separate_stopnet=c.separate_stopnet)
+                        separate_stopnet=c.separate_stopnet,
+                        bidirectional_decoder=c.bidirectional_decoder)
     return model
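The dimension wiring replaces two magic numbers: Tacotron's postnet targets the linear spectrogram (num_freq bins) while its decoder and all of Tacotron2 work in mel space. Illustrative only, with the names used in this diff:

audio = {'num_freq': 1025, 'num_mels': 80}  # the values previously hard-coded
tacotron_kwargs = dict(postnet_output_dim=audio['num_freq'],   # linear postnet target
                       decoder_output_dim=audio['num_mels'])   # mel decoder target
tacotron2_kwargs = dict(postnet_output_dim=audio['num_mels'],  # Tacotron2 stays mel-to-mel
                        decoder_output_dim=audio['num_mels'])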