mirror of https://github.com/coqui-ai/TTS.git
bug fixes for forward-backward training and load_data for parsing the data loader
This commit is contained in: parent e83a4b07d2, commit 5a56a2c096
config.json
@@ -1,6 +1,6 @@
 {
     "run_name": "ljspeech",
-    "run_description": "Tacotron ljspeech release training",
+    "run_description": "t bidirectional decoder test train",

     "audio":{
         // Audio processing parameters
@@ -34,7 +34,7 @@
     "model": "Tacotron", // one of the models in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
-    "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+    "lr": 0.001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "memory_size": -1, // ONLY TACOTRON - size of the memory queue used for storing last decoder predictions for auto-regression. If < 0, the memory queue is disabled and the decoder only uses the last prediction frame.
@@ -56,7 +56,7 @@
     "batch_size": 32, // Batch size for training. Values lower than 32 might cause hard-to-learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size": 16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 32], [1, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 96], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
    "wd": 0.000001, // Weight decay weight.
    "checkpoint": true, // If true, it saves checkpoints per "save_step".
    "save_step": 10000, // Number of training steps between saving training stats and checkpoints.
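Each "gradual_training" entry reads as [first_step, r, batch_size]: once global_step passes an entry's first step, training adopts that entry's reduction factor and batch size. A minimal sketch of how such a schedule can be resolved (a hypothetical standalone version; the repo's gradual_training_scheduler helper may differ in details):

    def gradual_training_scheduler(global_step, c):
        """Return (r, batch_size) for the current step.

        Assumes c.gradual_training entries are [first_step, r, batch_size]
        triples sorted by ascending first_step.
        """
        new_values = c.gradual_training[0]
        for values in c.gradual_training:
            if global_step >= values[0]:
                new_values = values
        return new_values[1], new_values[2]

Under the new schedule above, step 0 starts at r=7 with batch size 96, step 1 already moves to r=5 with batch size 64, and r then tightens toward 1 as training progresses.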
@@ -68,7 +68,8 @@
     "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null, the default English sentences are used.
     "min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 150, // DATASET-RELATED: maximum text length
-    "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
+    // "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
+    "output_path": "/media/erogol/data_ssd/Models/runs/",
     "num_loader_workers": 4, // number of training data loader processes. Don't set it too big; 4-8 are good values.
     "num_val_loader_workers": 4, // number of evaluation data loader processes.
     "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore it caches results in the given folder.
layers/tacotron.py
@@ -307,12 +307,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
         # learn init values instead of zero init.
-        self.stopnet = nn.Sequential(
-            nn.Dropout(0.1),
-            Linear(256 + memory_dim * self.r_init,
-                   1,
-                   bias=True,
-                   init_gain='sigmoid'))
+        self.stopnet = StopNet(256 + memory_dim * self.r_init)

     def set_r(self, new_r):
         self.r = new_r
@@ -321,11 +316,9 @@ class Decoder(nn.Module):
         """
         Reshape the spectrograms for given 'r'
         """
-        B = memory.shape[0]
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
-            memory = memory.contiguous()
-            memory = memory.view(B, memory.size(1) // self.r, -1)
+            memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
         # Time first (T_decoder, B, memory_dim)
         memory = memory.transpose(0, 1)
         return memory
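The decoder consumes r spectrogram frames per step, so _reshape_memory groups the (B, T, memory_dim) target into (B, T//r, r*memory_dim) and makes it time-major; the memory.size(-1) == self.memory_dim guard skips grouping when the input is already grouped. A standalone shape check with toy sizes (not repo code):

    import torch

    B, T, memory_dim, r = 2, 100, 80, 5
    memory = torch.randn(B, T, memory_dim)
    grouped = memory.view(memory.shape[0], memory.size(1) // r, -1)
    time_first = grouped.transpose(0, 1)
    print(grouped.shape)     # torch.Size([2, 20, 400])
    print(time_first.shape)  # torch.Size([20, 2, 400])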
@@ -356,7 +349,9 @@ class Decoder(nn.Module):
         attentions = torch.stack(attentions).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
+        outputs = outputs.view(
+            outputs.size(0), -1, self.memory_dim)
+        outputs = outputs.transpose(1, 2)
         return outputs, attentions, stop_tokens

     def decode(self, inputs, mask=None):
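This view change is one of the actual bug fixes: stacked decoder outputs flatten each step's r frames along the last axis, so the tensor must first be viewed as (B, T_out, memory_dim) and then transposed; viewing directly as (B, memory_dim, -1) yields the right shape but interleaves frames and channels. A toy illustration of the layout difference:

    import torch

    B, steps, r, memory_dim = 1, 2, 2, 3
    outputs = torch.arange(B * steps * r * memory_dim).view(B, steps, r * memory_dim)
    good = outputs.view(B, -1, memory_dim).transpose(1, 2)  # frames stay in order
    bad = outputs.view(B, memory_dim, -1)                   # same shape, scrambled layout
    print(good.shape == bad.shape)  # True
    print(torch.equal(good, bad))   # False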
@@ -405,6 +400,7 @@ class Decoder(nn.Module):
             self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
         else:
             # use only the last frame prediction
+            # assert new_memory.shape[-1] == self.r * self.memory_dim
             self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]

     def forward(self, inputs, memory, mask, speaker_embeddings=None):
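The newly added (commented) assert documents the invariant behind the slice below it: new_memory carries the r frames just predicted, flattened along the feature axis, and only the final frame feeds the next autoregressive step. A standalone shape check with assumed toy sizes:

    import torch

    B, r, memory_dim = 2, 5, 80
    new_memory = torch.randn(B, r * memory_dim)        # r flattened frames
    last_frame = new_memory[:, memory_dim * (r - 1):]  # keep frame r-1 only
    assert last_frame.shape == (B, memory_dim)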
@@ -479,20 +475,20 @@ class Decoder(nn.Module):
         return self._parse_outputs(outputs, attentions, stop_tokens)


-# class StopNet(nn.Module):
-#     r"""
-#     Args:
-#         in_features (int): feature dimension of input.
-#     """
+class StopNet(nn.Module):
+    r"""
+    Args:
+        in_features (int): feature dimension of input.
+    """

-#     def __init__(self, in_features):
-#         super(StopNet, self).__init__()
-#         self.dropout = nn.Dropout(0.1)
-#         self.linear = nn.Linear(in_features, 1)
-#         torch.nn.init.xavier_uniform_(
-#             self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
+    def __init__(self, in_features):
+        super(StopNet, self).__init__()
+        self.dropout = nn.Dropout(0.1)
+        self.linear = nn.Linear(in_features, 1)
+        torch.nn.init.xavier_uniform_(
+            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))

-#     def forward(self, inputs):
-#         outputs = self.dropout(inputs)
-#         outputs = self.linear(outputs)
-#         return outputs
+    def forward(self, inputs):
+        outputs = self.dropout(inputs)
+        outputs = self.linear(outputs)
+        return outputs
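Re-enabling StopNet pairs with the constructor change earlier in this file: the inline nn.Sequential stopnet becomes this module, which applies dropout and a linear projection to one stop logit per decoder step. Note the init gain also moves from 'sigmoid' to 'linear', which suggests the sigmoid is expected to live in the loss (e.g. a BCE-with-logits criterion). A quick usage sketch, assuming the class above is in scope:

    import torch

    memory_dim, r_init = 80, 7
    stopnet = StopNet(256 + memory_dim * r_init)  # as constructed in Decoder above
    step_features = torch.randn(4, 256 + memory_dim * r_init)
    stop_logits = stopnet(step_features)          # shape (4, 1), one logit per sample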
layers/tacotron2.py
@@ -181,17 +181,23 @@ class Decoder(nn.Module):
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask

-    def _reshape_memory(self, memories):
-        memories = memories.view(memories.size(0),
-                                 int(memories.size(1) / self.r), -1)
-        memories = memories.transpose(0, 1)
-        return memories
+    def _reshape_memory(self, memory):
+        """
+        Reshape the spectrograms for given 'r'
+        """
+        # Grouping multiple frames if necessary
+        if memory.size(-1) == self.memory_dim:
+            memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
+        # Time first (T_decoder, B, memory_dim)
+        memory = memory.transpose(0, 1)
+        return memory

     def _parse_outputs(self, outputs, stop_tokens, alignments):
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), self.memory_dim, -1)
+        outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
+        outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments

     def _update_memory(self, memory):
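This mirrors the _reshape_memory rewrite in the other Decoder above: both variants now share the same docstring, the same memory.size(-1) == self.memory_dim guard against regrouping already-grouped frames (with integer // replacing the old int(... / ...)), and the same view-then-transpose fix in _parse_outputs, so the toy shape checks shown earlier apply here unchanged.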
models/tacotron.py
@@ -125,7 +125,7 @@ class Tacotron(nn.Module):
         # B x T_out x postnet_dim
         postnet_outputs = self.last_linear(postnet_outputs)
         # B x T_out x decoder_dim
-        decoder_outputs = decoder_outputs.transpose(1, 2)
+        decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
         if self.bidirectional_decoder:
             decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
             return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
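The added .contiguous() matters because transpose only permutes strides and returns a view; downstream consumers that call .view() on the result need contiguous storage. A standalone illustration:

    import torch

    x = torch.randn(2, 80, 40).transpose(1, 2)
    print(x.is_contiguous())    # False: transpose does not move data
    y = x.contiguous()          # materializes the permuted layout
    print(y.view(2, -1).shape)  # torch.Size([2, 3200]); x.view(2, -1) would raise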
@@ -139,7 +139,7 @@ class Tacotron(nn.Module):
         if self.num_speakers > 1:
             inputs = self._concat_speaker_embedding(inputs,
                                                     self.speaker_embeddings)
         encoder_outputs = self.encoder(inputs)
         if self.gst and style_mel is not None:
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
         if self.num_speakers > 1:

@@ -147,16 +147,16 @@ class Tacotron(nn.Module):
                 encoder_outputs, self.speaker_embeddings)
         decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs, self.speaker_embeddings_projected)
-        decoder_outputs = decoder_outputs.view(B, -1, self.decoder_output_dim)
         postnet_outputs = self.postnet(decoder_outputs)
         postnet_outputs = self.last_linear(postnet_outputs)
+        decoder_outputs = decoder_outputs.transpose(1, 2)
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def _backward_inference(self, mel_specs, encoder_outputs, mask):
         decoder_outputs_b, alignments_b, _ = self.decoder_backward(
             encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
             self.speaker_embeddings_projected)
-        decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
+        decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
         return decoder_outputs_b, alignments_b

     def _compute_speaker_embedding(self, speaker_ids):
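For context on _backward_inference: the backward decoder reads the time-reversed mel target, so its output lives in reversed time order and is flipped back before it meets the forward decoder in the losses (see the train.py hunks below). A minimal sanity check of the flip convention:

    import torch

    mel_specs = torch.randn(2, 40, 80)                 # (B, T, memory_dim), toy shapes
    reversed_specs = torch.flip(mel_specs, dims=(1,))  # what decoder_backward consumes
    assert torch.equal(torch.flip(reversed_specs, dims=(1,)), mel_specs)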
train.py
@@ -40,12 +40,12 @@ print(" > Using CUDA: ", use_cuda)
 print(" > Number of GPUs: ", num_gpus)


-def setup_loader(ap, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
         dataset = MyDataset(
-            c.r,
+            r,
             c.text_cleaner,
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
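Passing r explicitly instead of reading c.r inside setup_loader is what makes gradual training consistent: callers now hand in model.decoder.r, so a rebuilt loader always groups frames with the decoder's current reduction factor rather than a possibly stale config value.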
@@ -72,11 +72,54 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader


-def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch):
-    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
+def format_data(data):
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)

+    # setup input data
+    text_input = data[0]
+    text_lengths = data[1]
+    speaker_names = data[2]
+    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
+                                          ] else None
+    mel_input = data[4]
+    mel_lengths = data[5]
+    stop_targets = data[6]
+    avg_text_length = torch.mean(text_lengths.float())
+    avg_spec_length = torch.mean(mel_lengths.float())
+
+    if c.use_speaker_embedding:
+        speaker_ids = [
+            speaker_mapping[speaker_name] for speaker_name in speaker_names
+        ]
+        speaker_ids = torch.LongTensor(speaker_ids)
+    else:
+        speaker_ids = None
+
+    # set stop targets view, we predict a single stop token per r frames prediction
+    stop_targets = stop_targets.view(text_input.shape[0],
+                                     stop_targets.size(1) // c.r, -1)
+    stop_targets = (stop_targets.sum(2) >
+                    0.0).unsqueeze(2).float().squeeze(2)
+
+    # dispatch data to GPU
+    if use_cuda:
+        text_input = text_input.cuda(non_blocking=True)
+        text_lengths = text_lengths.cuda(non_blocking=True)
+        mel_input = mel_input.cuda(non_blocking=True)
+        mel_lengths = mel_lengths.cuda(non_blocking=True)
+        linear_input = linear_input.cuda(
+            non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
+                                              ] else None
+        stop_targets = stop_targets.cuda(non_blocking=True)
+        if speaker_ids is not None:
+            speaker_ids = speaker_ids.cuda(non_blocking=True)
+    return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
+
+
+def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
+          ap, global_step, epoch):
+    data_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
     train_values = {
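Inside format_data, the stop-target reshape emits one stop flag per decoder step: per-frame flags are grouped into blocks of r, and a step counts as a stop if any frame in its block is flagged (the unsqueeze/squeeze pair around the cast is a no-op kept from the original code). A standalone check of that logic with toy tensors:

    import torch

    B, T, r = 1, 6, 3
    stop_targets = torch.tensor([[0., 0., 0., 0., 1., 1.]]).unsqueeze(-1)  # (B, T, 1)
    grouped = stop_targets.view(B, T // r, -1)  # (B, T // r, r)
    per_step = (grouped.sum(2) > 0.0).float()   # stop if any frame in the block stops
    print(per_step)                             # tensor([[0., 1.]])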
@@ -103,33 +146,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

-        # setup input data
-        text_input = data[0]
-        text_lengths = data[1]
-        speaker_names = data[2]
-        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
-                                              ] else None
-        mel_input = data[4]
-        mel_lengths = data[5]
-        stop_targets = data[6]
-        avg_text_length = torch.mean(text_lengths.float())
-        avg_spec_length = torch.mean(mel_lengths.float())
+        # format data
+        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
         loader_time = time.time() - end_time

-        if c.use_speaker_embedding:
-            speaker_ids = [
-                speaker_mapping[speaker_name] for speaker_name in speaker_names
-            ]
-            speaker_ids = torch.LongTensor(speaker_ids)
-        else:
-            speaker_ids = None
-
-        # set stop targets view, we predict a single stop token per r frames prediction
-        stop_targets = stop_targets.view(text_input.shape[0],
-                                         stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) >
-                        0.0).unsqueeze(2).float().squeeze(2)
-
         global_step += 1

         # setup lr
@@ -139,19 +159,6 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if optimizer_st:
             optimizer_st.zero_grad()

-        # dispatch data to GPU
-        if use_cuda:
-            text_input = text_input.cuda(non_blocking=True)
-            text_lengths = text_lengths.cuda(non_blocking=True)
-            mel_input = mel_input.cuda(non_blocking=True)
-            mel_lengths = mel_lengths.cuda(non_blocking=True)
-            linear_input = linear_input.cuda(
-                non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
-                                                  ] else None
-            stop_targets = stop_targets.cuda(non_blocking=True)
-            if speaker_ids is not None:
-                speaker_ids = speaker_ids.cuda(non_blocking=True)
-
         # forward pass model
         if c.bidirectional_decoder:
             decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
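Both removed blocks (the batch unpacking above and the GPU dispatch here) now live in format_data, so train() and evaluate() share a single parsing-and-dispatch path instead of two hand-maintained copies; this is the data-loader half of the commit's fixes.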
@@ -188,7 +195,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             else:
                 decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
             decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
-            loss = decoder_backward_loss + decoder_c_loss
+            loss += decoder_backward_loss + decoder_c_loss
             keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})

         loss.backward()
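This one-token change is the forward-backward training fix from the commit message: with loss = decoder_backward_loss + decoder_c_loss, everything accumulated into loss earlier in the iteration (the forward decoder and postnet terms) was silently discarded before loss.backward(); loss += keeps the full objective.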
@@ -278,7 +285,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             figures = {
                 "prediction": plot_spectrogram(const_spec, ap),
                 "ground_truth": plot_spectrogram(gt_spec, ap),
-                "alignment": plot_alignment(align_img)
+                "alignment": plot_alignment(align_img),
+                "alignment_backward": plot_alignment(alignments_backward[0].data.cpu().numpy())
             }
             tb_logger.tb_train_figures(global_step, figures)

@@ -320,7 +328,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,


 def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
-    data_loader = setup_loader(ap, is_val=True)
+    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
@@ -331,67 +339,29 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
         'avg_stop_loss': 0,
         'avg_align_score': 0
     }
+    if c.bidirectional_decoder:
+        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
+        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
     keep_avg = KeepAverage()
     keep_avg.add_values(eval_values_dict)
     print("\n > Validation")
-    if c.test_sentences_file is None:
-        test_sentences = [
-            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-            "Be a voice, not an echo.",
-            "I'm sorry Dave. I'm afraid I can't do that.",
-            "This cake is great. It's so delicious and moist."
-        ]
-    else:
-        with open(c.test_sentences_file, "r") as f:
-            test_sentences = [s.strip() for s in f.readlines()]
     with torch.no_grad():
         if data_loader is not None:
             for num_iter, data in enumerate(data_loader):
                 start_time = time.time()

-                # setup input data
-                text_input = data[0]
-                text_lengths = data[1]
-                speaker_names = data[2]
-                linear_input = data[3] if c.model in [
-                    "Tacotron", "TacotronGST"
-                ] else None
-                mel_input = data[4]
-                mel_lengths = data[5]
-                stop_targets = data[6]
-
-                if c.use_speaker_embedding:
-                    speaker_ids = [
-                        speaker_mapping[speaker_name]
-                        for speaker_name in speaker_names
-                    ]
-                    speaker_ids = torch.LongTensor(speaker_ids)
-                else:
-                    speaker_ids = None
-
-                # set stop targets view, we predict a single stop token per r frames prediction
-                stop_targets = stop_targets.view(text_input.shape[0],
-                                                 stop_targets.size(1) // c.r,
-                                                 -1)
-                stop_targets = (stop_targets.sum(2) >
-                                0.0).unsqueeze(2).float().squeeze(2)
-
-                # dispatch data to GPU
-                if use_cuda:
-                    text_input = text_input.cuda()
-                    mel_input = mel_input.cuda()
-                    mel_lengths = mel_lengths.cuda()
-                    linear_input = linear_input.cuda() if c.model in [
-                        "Tacotron", "TacotronGST"
-                    ] else None
-                    stop_targets = stop_targets.cuda()
-                    if speaker_ids is not None:
-                        speaker_ids = speaker_ids.cuda()
-
-                # forward pass
-                decoder_output, postnet_output, alignments, stop_tokens =\
-                    model.forward(text_input, text_lengths, mel_input,
-                                  speaker_ids=speaker_ids)
+                # format data
+                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
+                assert mel_input.shape[1] % model.decoder.r == 0
+
+                # forward pass model
+                if c.bidirectional_decoder:
+                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+                else:
+                    decoder_output, postnet_output, alignments, stop_tokens = model(
+                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

                 # loss computation
                 stop_loss = criterion_st(
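Two notes on this rewrite: the test-sentences block is moved rather than deleted (it reappears below, next to the test-audio synthesis in the @@ -479 hunk), and the new assert verifies that the loader, now built with model.decoder.r, produced mel batches whose time length is an exact multiple of the decoder's current reduction factor. The old eval-only dispatch block this replaces also never moved text_lengths to the GPU, which format_data now does.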
@@ -413,6 +383,16 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 postnet_loss = criterion(postnet_output, mel_input)
                 loss = decoder_loss + postnet_loss + stop_loss

+                # backward decoder loss
+                if c.bidirectional_decoder:
+                    if c.loss_masking:
+                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+                    else:
+                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+                    decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+                    loss += decoder_backward_loss + decoder_c_loss
+                    keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
                 step_time = time.time() - start_time
                 epoch_time += step_time
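The backward-decoder terms mirror the training-side ones above: the backward output is flipped into forward time order, matched against the ground truth (decoder_backward_loss), and tied to the forward decoder's prediction by an L1 consistency term (decoder_c_loss). A toy sketch, assuming an L1-style criterion and (B, T, memory_dim) outputs:

    import torch
    import torch.nn.functional as F

    mel_target = torch.randn(2, 40, 80)       # ground-truth mel, toy shapes
    decoder_output = torch.randn(2, 40, 80)   # forward decoder prediction
    backward_output = torch.randn(2, 40, 80)  # predicted on the reversed target

    flipped = torch.flip(backward_output, dims=(1,))       # back to forward order
    decoder_backward_loss = F.l1_loss(flipped, mel_target)
    decoder_c_loss = F.l1_loss(flipped, decoder_output)    # consistency term
    loss = decoder_backward_loss + decoder_c_loss          # added onto the forward loss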
@@ -433,7 +413,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     'avg_decoder_loss':
                     float(decoder_loss.item()),
                     'avg_stop_loss':
-                    float(stop_loss.item())
+                    float(stop_loss.item()),
                 })

                 if num_iter % c.print_step == 0:
@@ -479,13 +459,25 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
         }

         if c.bidirectional_decoder:
-            epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_backward']
-            epoch_figures['alignment_backward'] = alignments_backward[idx].data.cpu().numpy()
+            epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
+            align_b_img = alignments_backward[idx].data.cpu().numpy()
+            eval_figures['alignment_backward'] = plot_alignment(align_b_img)
         tb_logger.tb_eval_stats(global_step, epoch_stats)
         tb_logger.tb_eval_figures(global_step, eval_figures)

     if args.rank == 0 and epoch > c.test_delay_epochs:
+        if c.test_sentences_file is None:
+            test_sentences = [
+                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                "Be a voice, not an echo.",
+                "I'm sorry Dave. I'm afraid I can't do that.",
+                "This cake is great. It's so delicious and moist."
+            ]
+        else:
+            with open(c.test_sentences_file, "r") as f:
+                test_sentences = [s.strip() for s in f.readlines()]
+
         # test sentences
         test_audios = {}
         test_figures = {}
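Both eval-logging fixes here are small but real: the stats dict now reads the 'avg_decoder_b_loss' key that keep_avg actually accumulates (there is no 'avg_decoder_backward' value), and the backward alignment is rendered through plot_alignment into eval_figures, the dict that tb_eval_figures logs, instead of being stashed as a raw array.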
@@ -630,6 +622,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             r, c.batch_size = gradual_training_scheduler(global_step, c)
             c.r = r
             model.decoder.set_r(r)
+            if c.bidirectional_decoder: model.decoder_backward.set_r(r)
             print(" > Number of outputs per iteration:", model.decoder.r)

             train_loss, global_step = train(model, criterion, criterion_st,
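Without this line, gradual training would shrink r for the forward decoder only, leaving the backward decoder grouping frames with a stale reduction factor and breaking the bidirectional losses. Keeping both decoders, and the loader via setup_loader(ap, model.decoder.r, ...), on the same r is the thread that ties this commit together.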