comment model outputs for tacotron, align output shapes of tacotron and tacotron2, merge bidirectional decoder

Eren Golge 2019-10-28 14:51:19 +01:00
parent 1c2d26ccb9
commit e83a4b07d2
7 changed files with 207 additions and 113 deletions

config.json (View File)

@@ -46,6 +46,7 @@
"forward_attn_mask": false,
"transition_agent": false, // enable/disable transition agent of forward attention.
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
@@ -82,7 +83,8 @@
[
{
"name": "ljspeech",
"path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
// "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/",
"path": "/home/erogol/Data/LJSpeech-1.1",
"meta_file_train": "metadata_train.csv",
"meta_file_val": "metadata_val.csv"
}

layers/tacotron.py (View File)

@@ -1,7 +1,7 @@
# coding: utf-8
import torch
from torch import nn
from .common_layers import Prenet, Attention
from .common_layers import Prenet, Attention, Linear
class BatchNormConv1d(nn.Module):
@@ -125,13 +125,12 @@ class CBHG(nn.Module):
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList([
BatchNormConv1d(
in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu) for k in range(1, K + 1)
BatchNormConv1d(in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu) for k in range(1, K + 1)
])
# max pooling of conv bank, with padding
# TODO: try average pooling OR larger kernel size
@@ -142,39 +141,33 @@ class CBHG(nn.Module):
layer_set = []
for (in_size, out_size, ac) in zip(out_features, conv_projections,
activations):
layer = BatchNormConv1d(
in_size,
out_size,
kernel_size=3,
stride=1,
padding=[1, 1],
activation=ac)
layer = BatchNormConv1d(in_size,
out_size,
kernel_size=3,
stride=1,
padding=[1, 1],
activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
if self.highway_features != conv_projections[-1]:
self.pre_highway = nn.Linear(
conv_projections[-1], highway_features, bias=False)
self.pre_highway = nn.Linear(conv_projections[-1],
highway_features,
bias=False)
self.highways = nn.ModuleList([
Highway(highway_features, highway_features)
for _ in range(num_highways)
])
# bi-directional GRU layer
self.gru = nn.GRU(
gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
self.gru = nn.GRU(gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
def forward(self, inputs):
# (B, T_in, in_features)
x = inputs
# Needed to perform conv1d on time-axis
# (B, in_features, T_in)
if x.size(-1) == self.in_features:
x = x.transpose(1, 2)
# T = x.size(-1)
x = inputs
# (B, hid_features*K, T_in)
# Concat conv1d bank outputs
outs = []
@@ -185,10 +178,8 @@ class CBHG(nn.Module):
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
for conv1d in self.conv1d_projections:
x = conv1d(x)
# (B, T_in, hid_feature)
x = x.transpose(1, 2)
# Back to the original shape
x += inputs
x = x.transpose(1, 2)
if self.highway_features != self.conv_projections[-1]:
x = self.pre_highway(x)
# Residual connection
@@ -236,8 +227,10 @@ class Encoder(nn.Module):
- inputs: batch x time x in_features
- outputs: batch x time x 128*2
"""
inputs = self.prenet(inputs)
return self.cbhg(inputs)
# B x T x prenet_dim
outputs = self.prenet(inputs)
outputs = self.cbhg(outputs.transpose(1, 2))
return outputs
class PostCBHG(nn.Module):
@@ -314,7 +307,12 @@ class Decoder(nn.Module):
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
# learn init values instead of zero init.
self.stopnet = StopNet(256 + memory_dim * self.r_init)
self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(256 + memory_dim * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
def set_r(self, new_r):
self.r = new_r
@@ -356,8 +354,9 @@ class Decoder(nn.Module):
def _parse_outputs(self, outputs, attentions, stop_tokens):
# Back to batch first
attentions = torch.stack(attentions).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
return outputs, attentions, stop_tokens
def decode(self, inputs, mask=None):
@@ -438,9 +437,8 @@ class Decoder(nn.Module):
output, stop_token, attention = self.decode(inputs, mask)
outputs += [output]
attentions += [attention]
stop_tokens += [stop_token]
stop_tokens += [stop_token.squeeze(1)]
t += 1
return self._parse_outputs(outputs, attentions, stop_tokens)
def inference(self, inputs, speaker_embeddings=None):
@@ -481,20 +479,20 @@ class Decoder(nn.Module):
return self._parse_outputs(outputs, attentions, stop_tokens)
class StopNet(nn.Module):
r"""
Args:
in_features (int): feature dimension of input.
"""
# class StopNet(nn.Module):
# r"""
# Args:
# in_features (int): feature dimension of input.
# """
def __init__(self, in_features):
super(StopNet, self).__init__()
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(in_features, 1)
torch.nn.init.xavier_uniform_(
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
# def __init__(self, in_features):
# super(StopNet, self).__init__()
# self.dropout = nn.Dropout(0.1)
# self.linear = nn.Linear(in_features, 1)
# torch.nn.init.xavier_uniform_(
# self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs):
outputs = self.dropout(inputs)
outputs = self.linear(outputs)
return outputs
# def forward(self, inputs):
# outputs = self.dropout(inputs)
# outputs = self.linear(outputs)
# return outputs
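For context, the new stopnet above replaces the old StopNet class with an nn.Sequential built from the Linear wrapper now imported from common_layers. A minimal sketch of that wrapper, assuming it is just nn.Linear plus gain-scaled Xavier initialization:

import torch
from torch import nn

class Linear(nn.Module):
    """nn.Linear with Xavier-uniform init scaled for a target activation."""
    def __init__(self, in_features, out_features, bias=True, init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias=bias)
        # init_gain='sigmoid' suits the stopnet, whose output feeds a sigmoid
        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)

Under that assumption the new stopnet matches the commented-out class, except its weights are initialized with a sigmoid gain instead of a linear one.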

layers/tacotron2.py (View File)

@@ -98,11 +98,12 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, separate_stopnet):
forward_attn_mask, location_attn, separate_stopnet,
speaker_embedding_dim):
super(Decoder, self).__init__()
self.mel_channels = inputs_dim
self.memory_dim = memory_dim
self.r_init = r
self.r = r
self.encoder_embedding_dim = in_features
@@ -114,11 +115,15 @@ class Decoder(nn.Module):
self.gate_threshold = 0.5
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(self.mel_channels,
prenet_type,
prenet_dropout,
[self.prenet_dim, self.prenet_dim],
bias=False)
# memory -> |Prenet| -> processed_memory
prenet_dim = self.memory_dim + speaker_embedding_dim
self.prenet = Prenet(
prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.query_dim)
@@ -139,11 +144,11 @@ class Decoder(nn.Module):
self.decoder_rnn_dim, 1)
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
self.mel_channels * self.r_init)
self.memory_dim * self.r_init)
self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
@@ -155,7 +160,7 @@ class Decoder(nn.Module):
def get_go_frame(self, inputs):
B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(B,
self.mel_channels * self.r)
self.memory_dim * self.r)
return memory
def _init_states(self, inputs, mask, keep_states=False):
@@ -185,16 +190,14 @@ class Decoder(nn.Module):
def _parse_outputs(self, outputs, stop_tokens, alignments):
alignments = torch.stack(alignments).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
stop_tokens = stop_tokens.contiguous()
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
outputs = outputs.transpose(1, 2)
outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
return outputs, stop_tokens, alignments
def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.mel_channels * (self.r - 1):]
return memory[:, :, self.mel_channels * (self.r - 1):]
return memory[:, self.memory_dim * (self.r - 1):]
return memory[:, :, self.memory_dim * (self.r - 1):]
def decode(self, memory):
query_input = torch.cat((memory, self.context), -1)
@@ -228,10 +231,10 @@ class Decoder(nn.Module):
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
decoder_output = decoder_output[:, :self.r * self.mel_channels]
decoder_output = decoder_output[:, :self.r * self.memory_dim]
return decoder_output, stop_token, self.attention.attention_weights
def forward(self, inputs, memories, mask):
def forward(self, inputs, memories, mask, speaker_embeddings=None):
memory = self.get_go_frame(inputs).unsqueeze(0)
memories = self._reshape_memory(memories)
memories = torch.cat((memory, memories), dim=0)
@@ -243,6 +246,8 @@ class Decoder(nn.Module):
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
memory = memories[len(outputs)]
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
mel_output, stop_token, attention_weights = self.decode(memory)
outputs += [mel_output.squeeze(1)]
stop_tokens += [stop_token.squeeze(1)]
@@ -253,7 +258,7 @@ class Decoder(nn.Module):
return outputs, stop_tokens, alignments
def inference(self, inputs):
def inference(self, inputs, speaker_embeddings=None):
memory = self.get_go_frame(inputs)
memory = self._update_memory(memory)
@@ -266,6 +271,8 @@ class Decoder(nn.Module):
stop_flags = [True, False, False]
while True:
memory = self.prenet(memory)
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
mel_output, stop_token, alignment = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
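Aside: speaker conditioning here is a plain feature-axis concatenation in front of the decoder input, which is why the prenet input above grows to memory_dim + speaker_embedding_dim. A self-contained shape check with hypothetical sizes:

import torch

B, memory_dim, speaker_dim = 2, 80, 80
memory = torch.randn(B, memory_dim)               # one decoder input frame
speaker_embeddings = torch.randn(B, speaker_dim)  # projected speaker vector
prenet_input = torch.cat([memory, speaker_embeddings], dim=-1)
assert prenet_input.shape == (B, memory_dim + speaker_dim)  # = prenet_dim above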

models/tacotron.py (View File)

@@ -1,5 +1,6 @@
# coding: utf-8
import torch
import copy
from torch import nn
from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
from TTS.utils.generic_utils import sequence_mask
@@ -11,8 +12,8 @@ class Tacotron(nn.Module):
num_chars,
num_speakers,
r=5,
linear_dim=1025,
mel_dim=80,
postnet_output_dim=1025,
decoder_output_dim=80,
memory_size=5,
attn_win=False,
gst=False,
@@ -23,28 +24,33 @@ class Tacotron(nn.Module):
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True):
separate_stopnet=True,
bidirectional_decoder=False):
super(Tacotron, self).__init__()
self.r = r
self.mel_dim = mel_dim
self.linear_dim = linear_dim
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.gst = gst
self.num_speakers = num_speakers
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
self.bidirectional_decoder = bidirectional_decoder
decoder_dim = 512 if num_speakers > 1 else 256
encoder_dim = 512 if num_speakers > 1 else 256
proj_speaker_dim = 80 if num_speakers > 1 else 0
# embedding layer
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
# boilerplate model
self.encoder = Encoder(encoder_dim)
self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, separate_stopnet,
proj_speaker_dim)
self.postnet = PostCBHG(mel_dim)
if self.bidirectional_decoder:
self.decoder_backward = copy.deepcopy(self.decoder)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
linear_dim)
postnet_output_dim)
# speaker embedding layers
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 256)
@@ -82,27 +88,48 @@ class Tacotron(nn.Module):
return inputs
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
"""
Shapes:
- characters: B x T_in
- text_lengths: B
- mel_specs: B x T_out x D
- speaker_ids: B x 1
"""
self._init_states()
B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device)
# B x T_in x embed_dim
inputs = self.embedding(characters)
self._init_states()
# B x speaker_embed_dim
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
# B x T_in x embed_dim + speaker_embed_dim
inputs = self._concat_speaker_embedding(inputs,
self.speaker_embeddings)
# B x T_in x encoder_dim
encoder_outputs = self.encoder(inputs)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
if self.num_speakers > 1:
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, self.speaker_embeddings)
mel_outputs, alignments, stop_tokens = self.decoder(
# decoder_outputs: B x decoder_dim x T_out
# alignments: B x T_decoder x T_in
# stop_tokens: B x T_decoder
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask,
self.speaker_embeddings_projected)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
# B x T_out x decoder_dim
postnet_outputs = self.postnet(decoder_outputs)
# B x T_out x postnet_dim
postnet_outputs = self.last_linear(postnet_outputs)
# B x T_out x decoder_dim
decoder_outputs = decoder_outputs.transpose(1, 2)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference(self, characters, speaker_ids=None, style_mel=None):
B = characters.size(0)
@@ -118,12 +145,19 @@ class Tacotron(nn.Module):
if self.num_speakers > 1:
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, self.speaker_embeddings)
mel_outputs, alignments, stop_tokens = self.decoder.inference(
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs, self.speaker_embeddings_projected)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
decoder_outputs = decoder_outputs.view(B, -1, self.decoder_output_dim)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def _backward_inference(self, mel_specs, encoder_outputs, mask):
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
self.speaker_embeddings_projected)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
return decoder_outputs_b, alignments_b
def _compute_speaker_embedding(self, speaker_ids):
speaker_embeddings = self.speaker_embedding(speaker_ids)
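Note on the bidirectional decoder: copy.deepcopy gives the backward decoder its own parameters with the same architecture, so the two decoders train independently and are tied only through the consistency loss in train.py. A tiny self-contained sketch of that property (toy module, not the actual Decoder):

import copy
import torch
from torch import nn

decoder = nn.GRU(4, 4, batch_first=True)   # stand-in for the real Decoder
decoder_backward = copy.deepcopy(decoder)  # same architecture, separate weights
with torch.no_grad():
    next(decoder_backward.parameters()).add_(1.0)
# mutating the copy leaves the original untouched
assert not torch.equal(next(decoder.parameters()),
                       next(decoder_backward.parameters()))

_backward_inference then simply feeds torch.flip(mel_specs, dims=(1,)) through the copy and transposes the result back to B x T x decoder_dim.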

models/tacotron2.py (View File)

@@ -1,3 +1,5 @@
import copy
import torch
from math import sqrt
from torch import nn
from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -10,6 +12,8 @@ class Tacotron2(nn.Module):
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_win=False,
attn_norm="softmax",
prenet_type="original",
@@ -18,10 +22,16 @@ class Tacotron2(nn.Module):
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True):
separate_stopnet=True,
bidirectional_decoder=False):
super(Tacotron2, self).__init__()
self.n_mel_channels = 80
self.decoder_output_dim = decoder_output_dim
self.n_frames_per_step = r
self.bidirectional_decoder = bidirectional_decoder
decoder_dim = 512 + 256 if num_speakers > 1 else 512
encoder_dim = 512 + 256 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# embedding layer
self.embedding = nn.Embedding(num_chars, 512)
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
@@ -29,12 +39,18 @@ class Tacotron2(nn.Module):
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
self.encoder = Encoder(encoder_dim)
self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, separate_stopnet)
self.postnet = Postnet(self.n_mel_channels)
location_attn, separate_stopnet, proj_speaker_dim)
if self.bidirectional_decoder:
self.decoder_backward = copy.deepcopy(self.decoder)
self.postnet = Postnet(self.decoder_output_dim)
def _init_states(self):
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@@ -43,19 +59,23 @@ class Tacotron2(nn.Module):
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
self._init_states()
# compute mask for padding
mask = sequence_mask(text_lengths).to(text.device)
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, stop_tokens, alignments = self.decoder(
decoder_outputs, stop_tokens, alignments = self.decoder(
encoder_outputs, mel_specs, mask)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference(self, text, speaker_ids=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
@@ -86,6 +106,13 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def _backward_inference(self, mel_specs, encoder_outputs, mask):
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
self.speaker_embeddings_projected)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
return decoder_outputs_b, alignments_b
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")

train.py (View File)

@@ -88,6 +88,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
'avg_loader_time': 0,
'avg_alignment_score': 0
}
if c.bidirectional_decoder:
train_values['avg_decoder_b_loss'] = 0 # decoder backward loss
train_values['avg_decoder_c_loss'] = 0 # decoder consistency loss
keep_avg = KeepAverage()
keep_avg.add_values(train_values)
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
@@ -150,8 +153,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
speaker_ids = speaker_ids.cuda(non_blocking=True)
# forward pass model
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
if c.bidirectional_decoder:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
# loss computation
stop_loss = criterion_st(stop_tokens,
@@ -174,6 +181,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
if not c.separate_stopnet and c.stopnet:
loss += stop_loss
# backward decoder
if c.bidirectional_decoder:
if c.loss_masking:
decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
else:
decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
loss += decoder_backward_loss + decoder_c_loss
keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
loss.backward()
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip)
@@ -445,7 +462,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)
}
tb_logger.tb_eval_figures(global_step, eval_figures)
# Sample audio
if c.model in ["Tacotron", "TacotronGST"]:
@@ -461,7 +477,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
"loss_decoder": keep_avg['avg_decoder_loss'],
"stop_loss": keep_avg['avg_stop_loss']
}
if c.bidirectional_decoder:
epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
eval_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
tb_logger.tb_eval_stats(global_step, epoch_stats)
tb_logger.tb_eval_figures(global_step, eval_figures)
if args.rank == 0 and epoch > c.test_delay_epochs:
# test sentences
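To summarize the two new terms: the backward decoder's prediction is flipped back to forward time, scored against the ground-truth mels, and tied to the forward decoder with an L1 consistency term. A self-contained sketch with dummy tensors and unmasked L1 losses (the real code uses the configured criterion and optional loss masking):

import torch
import torch.nn.functional as F

B, T, D = 2, 8, 80
mel_input = torch.randn(B, T, D)                                    # ground truth
decoder_output = torch.randn(B, T, D, requires_grad=True)           # forward decoder
decoder_backward_output = torch.randn(B, T, D, requires_grad=True)  # reversed time

backward_fwd_time = torch.flip(decoder_backward_output, dims=(1,))  # undo reversal
decoder_backward_loss = F.l1_loss(backward_fwd_time, mel_input)
decoder_c_loss = F.l1_loss(backward_fwd_time, decoder_output)       # consistency
loss = decoder_backward_loss + decoder_c_loss  # accumulated into the main loss
loss.backward()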

utils/generic_utils.py (View File)

@@ -283,8 +283,8 @@ def setup_model(num_chars, num_speakers, c):
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
linear_dim=1025,
mel_dim=80,
postnet_output_dim=c.audio['num_freq'],
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
memory_size=c.memory_size,
attn_win=c.windowing,
@@ -295,11 +295,14 @@ def setup_model(num_chars, num_speakers, c):
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
@@ -308,7 +311,8 @@ def setup_model(num_chars, num_speakers, c):
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
return model
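With the hard-coded 1025/80 gone, both output sizes come from the audio config: Tacotron's postnet predicts linear spectrograms (num_freq bins) while its decoder stays in mel space, and Tacotron2 uses mels for both. For a typical LJSpeech-style setup (hypothetical values), num_freq follows from the FFT size:

fft_size = 2048               # hypothetical STFT size
num_freq = fft_size // 2 + 1  # 1025 linear bins -> Tacotron postnet_output_dim
num_mels = 80                 # mel bins -> decoder_output_dim (both dims for Tacotron2)
print(num_freq, num_mels)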