diff --git a/config.json b/config.json index 01860746..838f1510 100644 --- a/config.json +++ b/config.json @@ -12,18 +12,18 @@ "text_cleaner": "english_cleaners", "epochs": 2000, - "lr": 0.003, - "batch_size": 180, + "lr": 0.0006, + "warmup_steps": 4000, + "batch_size": 32, "r": 5, "griffin_lim_iters": 60, "power": 1.5, - "num_loader_workers": 32, + "num_loader_workers": 16, "checkpoint": false, "save_step": 69, - "data_path": "/data/shared/KeithIto/LJSpeech-1.0", - "output_path": "result", - "log_dir": "/home/erogol/projects/TTS/logs/" + "data_path": "/run/shm/erogol/LJSpeech-1.0", + "output_path": "result" } diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 3fe07078..a42a626e 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -1,4 +1,3 @@ -import pandas as pd import os import numpy as np import collections @@ -16,16 +15,18 @@ class LJSpeechDataset(Dataset): def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, text_cleaner, num_mels, min_level_db, frame_shift_ms, frame_length_ms, preemphasis, ref_level_db, num_freq, power): - self.frames = pd.read_csv(csv_file, sep='|', header=None) + + with open(csv_file, "r") as f: + self.frames = [line.split('|') for line in f] self.root_dir = root_dir self.outputs_per_step = outputs_per_step self.sample_rate = sample_rate self.cleaners = text_cleaner self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, preemphasis, ref_level_db, num_freq, power - ) + frame_length_ms, preemphasis, ref_level_db, num_freq, power) print(" > Reading LJSpeech from - {}".format(root_dir)) print(" | > Number of instances : {}".format(len(self.frames))) + self._sort_frames() def load_wav(self, filename): try: @@ -34,22 +35,44 @@ class LJSpeechDataset(Dataset): except RuntimeError as e: print(" !! Cannot read file : {}".format(filename)) + def _sort_frames(self): + r"""Sort sequences in ascending order""" + lengths = np.array([len(ins[1]) for ins in self.frames]) + + print(" | > Max length sequence {}".format(np.max(lengths))) + print(" | > Min length sequence {}".format(np.min(lengths))) + print(" | > Avg length sequence {}".format(np.mean(lengths))) + + idxs = np.argsort(lengths) + new_frames = [None] * len(lengths) + for i, idx in enumerate(idxs): + new_frames[i] = self.frames[idx] + self.frames = new_frames + def __len__(self): return len(self.frames) def __getitem__(self, idx): wav_name = os.path.join(self.root_dir, - self.frames.ix[idx, 0]) + '.wav' - text = self.frames.ix[idx, 1] + self.frames[idx][0]) + '.wav' + text = self.frames[idx][1] text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) - sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]} + sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} return sample def get_dummy_data(self): + r"""Get a dummy input for testing""" return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor) def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. PAD sequences with the longest sequence in the batch + 2. Convert Audio signal to Spectrograms. + 3. PAD sequences that can be divided by r. + 4. Convert Numpy to Torch tensors. + """ # Puts each data field into a tensor with outer dimension batch size if isinstance(batch[0], collections.Mapping): diff --git a/layers/attention.py b/layers/attention.py index 8d993cea..e7385149 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -5,26 +5,27 @@ from torch.nn import functional as F class BahdanauAttention(nn.Module): - def __init__(self, dim): + def __init__(self, annot_dim, query_dim, hidden_dim): super(BahdanauAttention, self).__init__() - self.query_layer = nn.Linear(dim, dim, bias=False) - self.tanh = nn.Tanh() - self.v = nn.Linear(dim, 1, bias=False) + self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True) + self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True) + self.v = nn.Linear(hidden_dim, 1, bias=False) - def forward(self, query, processed_inputs): + def forward(self, annots, query): """ - Args: - query: (batch, 1, dim) or (batch, dim) - processed_inputs: (batch, max_time, dim) + Shapes: + - query: (batch, 1, dim) or (batch, dim) + - annots: (batch, max_time, dim) """ if query.dim() == 2: # insert time-axis for broadcasting query = query.unsqueeze(1) # (batch, 1, dim) processed_query = self.query_layer(query) + processed_annots = self.annot_layer(annots) # (batch, max_time, 1) - alignment = self.v(self.tanh(processed_query + processed_inputs)) + alignment = self.v(nn.functional.tanh(processed_query + processed_annots)) # (batch, max_time) return alignment.squeeze(-1) @@ -34,7 +35,7 @@ def get_mask_from_lengths(inputs, inputs_lengths): """Get mask tensor from list of length Args: - inputs: (batch, max_time, dim) + inputs: Tensor in size (batch, max_time, dim) inputs_lengths: array like """ mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_() @@ -43,52 +44,48 @@ def get_mask_from_lengths(inputs, inputs_lengths): return ~mask -class AttentionWrapper(nn.Module): - def __init__(self, rnn_cell, alignment_model, +class AttentionRNN(nn.Module): + def __init__(self, out_dim, annot_dim, memory_dim, score_mask_value=-float("inf")): - super(AttentionWrapper, self).__init__() - self.rnn_cell = rnn_cell - self.alignment_model = alignment_model + super(AttentionRNN, self).__init__() + self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim) + self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim) self.score_mask_value = score_mask_value - def forward(self, query, context_vec, cell_state, inputs, - processed_inputs=None, mask=None, inputs_lengths=None): + def forward(self, memory, context, rnn_state, annotations, + mask=None, annotations_lengths=None): - if processed_inputs is None: - processed_inputs = inputs - - if inputs_lengths is not None and mask is None: - mask = get_mask_from_lengths(inputs, inputs_lengths) + if annotations_lengths is not None and mask is None: + mask = get_mask_from_lengths(annotations, annotations_lengths) # Alignment # (batch, max_time) # e_{ij} = a(s_{i-1}, h_j) - # import ipdb - # ipdb.set_trace() - alignment = self.alignment_model(cell_state, processed_inputs) + alignment = self.alignment_model(annotations, rnn_state) + # TODO: needs recheck. if mask is not None: mask = mask.view(query.size(0), -1) alignment.data.masked_fill_(mask, self.score_mask_value) - # Normalize context_vec weight + # Normalize context weight alignment = F.softmax(alignment, dim=-1) # Attention context vector # (batch, 1, dim) # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j - context_vec = torch.bmm(alignment.unsqueeze(1), inputs) - context_vec = context_vec.squeeze(1) + context = torch.bmm(alignment.unsqueeze(1), annotations) + context = context.squeeze(1) - # Concat input query and previous context_vec context - cell_input = torch.cat((query, context_vec), -1) - #cell_input = cell_input.unsqueeze(1) + # Concat input query and previous context context + rnn_input = torch.cat((memory, context), -1) + #rnn_input = rnn_input.unsqueeze(1) # Feed it to RNN # s_i = f(y_{i-1}, c_{i}, s_{i-1}) - cell_output = self.rnn_cell(cell_input, cell_state) + rnn_output = self.rnn_cell(rnn_input, rnn_state) - context_vec = context_vec.squeeze(1) - return cell_output, context_vec, alignment + context = context.squeeze(1) + return rnn_output, context, alignment diff --git a/layers/tacotron.py b/layers/tacotron.py index 9b31d02b..6f5926a8 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -3,7 +3,7 @@ import torch from torch.autograd import Variable from torch import nn -from .attention import BahdanauAttention, AttentionWrapper +from .attention import AttentionRNN from .attention import get_mask_from_lengths class Prenet(nn.Module): @@ -153,7 +153,7 @@ class CBHG(nn.Module): out = conv1d(x) out = out[:, :, :T] outs.append(out) - + x = torch.cat(outs, dim=1) assert x.size(1) == self.in_features * len(self.conv1d_banks) @@ -219,15 +219,10 @@ class Decoder(nn.Module): self.memory_dim = memory_dim self.eps = eps self.r = r - # input -> |Linear| -> processed_inputs - self.input_layer = nn.Linear(in_features, 256, bias=False) # memory -> |Prenet| -> processed_memory self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State - self.attention_rnn = AttentionWrapper( - nn.GRUCell(in_features + 128, 256), - BahdanauAttention(256) - ) + self.attention_rnn = AttentionRNN(256, in_features, 128) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256+in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state @@ -236,7 +231,7 @@ class Decoder(nn.Module): # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) - def forward(self, inputs, memory=None, memory_lengths=None): + def forward(self, inputs, memory=None, input_lengths=None): r""" Decoder forward step. @@ -245,9 +240,9 @@ class Decoder(nn.Module): Args: inputs: Encoder outputs. - memory: Decoder memory (autoregression. If None (at eval-time), + memory (None): Decoder memory (autoregression. If None (at eval-time), decoder outputs are used as decoder inputs. - memory_lengths: Encoder output (memory) lengths. If not None, used for + input_lengths (None): input lengths, used for attention masking. Shapes: @@ -256,12 +251,11 @@ class Decoder(nn.Module): """ B = inputs.size(0) - # TODO: take this segment into Attention module. - processed_inputs = self.input_layer(inputs) - if memory_lengths is not None: - mask = get_mask_from_lengths(processed_inputs, memory_lengths) - else: - mask = None + + # if input_lengths is not None: + # mask = get_mask_from_lengths(processed_inputs, input_lengths) + # else: + # mask = None # Run greedy decoding if memory is None greedy = memory is None @@ -301,13 +295,14 @@ class Decoder(nn.Module): while True: if t > 0: memory_input = outputs[-1] if greedy else memory[t - 1] + # Prenet processed_memory = self.prenet(memory_input) # Attention RNN attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn( processed_memory, current_context_vec, attention_rnn_hidden, - inputs, processed_inputs=processed_inputs, mask=mask) + inputs) # Concat RNN output and attention context vector decoder_input = self.project_to_decoder_in( @@ -350,5 +345,5 @@ class Decoder(nn.Module): return outputs, alignments -def is_end_of_frames(output, eps=0.1): #0.2 +def is_end_of_frames(output, eps=0.2): #0.2 return (output.data <= eps).all() diff --git a/models/tacotron.py b/models/tacotron.py index 6e67edf8..57c9b43d 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -9,11 +9,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG class Tacotron(nn.Module): def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80, freq_dim=1025, r=5, padding_idx=None, - use_memory_mask=False): + use_atten_mask=False): super(Tacotron, self).__init__() self.mel_dim = mel_dim self.linear_dim = linear_dim - self.use_memory_mask = use_memory_mask + self.use_atten_mask = use_atten_mask self.embedding = nn.Embedding(len(symbols), embedding_dim, padding_idx=padding_idx) print(" | > Embedding dim : {}".format(len(symbols))) @@ -33,13 +33,12 @@ class Tacotron(nn.Module): # (B, T', in_dim) encoder_outputs = self.encoder(inputs) - if self.use_memory_mask: - memory_lengths = input_lengths - else: - memory_lengths = None + if not self.use_atten_mask: + input_lengths = None + # (B, T', mel_dim*r) mel_outputs, alignments = self.decoder( - encoder_outputs, mel_specs, memory_lengths=memory_lengths) + encoder_outputs, mel_specs, input_lengths=input_lengths) # Post net processing below diff --git a/tests/generic_utils_text.py b/tests/generic_utils_text.py new file mode 100644 index 00000000..0461d263 --- /dev/null +++ b/tests/generic_utils_text.py @@ -0,0 +1,38 @@ +import unittest +import torch as T + +from TTS.utils.generic_utils import save_checkpoint, save_best_model +from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder + +OUT_PATH = '/tmp/test.pth.tar' + +class ModelSavingTests(unittest.TestCase): + + def save_checkpoint_test(self): + # create a dummy model + model = Prenet(128, out_features=[256, 128]) + model = T.nn.DataParallel(layer) + + # save the model + save_checkpoint(model, None, 100, + OUTPATH, 1, 1) + + # load the model to CPU + model_dict = torch.load(MODEL_PATH, map_location=lambda storage, + loc: storage) + model.load_state_dict(model_dict['model']) + + def save_best_model_test(self): + # create a dummy model + model = Prenet(256, out_features=[256, 256]) + model = T.nn.DataParallel(layer) + + # save the model + best_loss = save_best_model(model, None, 0, + 100, OUT_PATH, + 10, 1) + + # load the model to CPU + model_dict = torch.load(MODEL_PATH, map_location=lambda storage, + loc: storage) + model.load_state_dict(model_dict['model']) diff --git a/train.py b/train.py index 0d432cce..8aa6567d 100644 --- a/train.py +++ b/train.py @@ -21,42 +21,275 @@ from tensorboardX import SummaryWriter from utils.generic_utils import (Progbar, remove_experiment_folder, create_experiment_folder, save_checkpoint, save_best_model, load_config, lr_decay, - count_parameters) + count_parameters, check_update) from utils.model import get_param_size from utils.visual import plot_alignment, plot_spectrogram from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron + use_cuda = torch.cuda.is_available() +parser = argparse.ArgumentParser() +parser.add_argument('--restore_path', type=str, + help='Folder path to checkpoints', default=0) +parser.add_argument('--config_path', type=str, + help='path to config file for training',) +args = parser.parse_args() + +# setup output paths and read configs +c = load_config(args.config_path) +_ = os.path.dirname(os.path.realpath(__file__)) +OUT_PATH = os.path.join(_, c.output_path) +OUT_PATH = create_experiment_folder(OUT_PATH) +CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') +shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) + +# save config to tmp place to be loaded by subsequent modules. +file_name = str(os.getpid()) +tmp_path = os.path.join("/tmp/", file_name+'_tts') +pickle.dump(c, open(tmp_path, "wb")) + +# setup tensorboard +LOG_DIR = OUT_PATH +tb = SummaryWriter(LOG_DIR) + + +def signal_handler(signal, frame): + """Ctrl+C handler to remove empty experiment folder""" + print(" !! Pressed Ctrl+C !!") + remove_experiment_folder(OUT_PATH) + sys.exit(1) + + +def train(model, criterion, data_loader, optimizer, epoch): + model = model.train() + epoch_time = 0 + avg_linear_loss = 0 + avg_mel_loss = 0 + + print(" | > Epoch {}/{}".format(epoch, c.epochs)) + progbar = Progbar(len(data_loader.dataset) / c.batch_size) + n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + + current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1 + + # setup lr + current_lr = lr_decay(c.lr, current_step, c.warmup_steps) + for params_group in optimizer.param_groups: + params_group['lr'] = current_lr + + optimizer.zero_grad() + + # convert inputs to variables + text_input_var = Variable(text_input) + mel_spec_var = Variable(mel_input) + linear_spec_var = Variable(linear_input, volatile=True) + + # sort sequence by length for curriculum learning + # TODO: might be unnecessary + sorted_lengths, indices = torch.sort( + text_lengths.view(-1), dim=0, descending=True) + sorted_lengths = sorted_lengths.long().numpy() + text_input_var = text_input_var[indices] + mel_spec_var = mel_spec_var[indices] + linear_spec_var = linear_spec_var[indices] + + # dispatch data to GPU + if use_cuda: + text_input_var = text_input_var.cuda() + mel_spec_var = mel_spec_var.cuda() + linear_spec_var = linear_spec_var.cuda() + + # forward pass + mel_output, linear_output, alignments =\ + model.forward(text_input_var, mel_spec_var, + input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths))) + + # loss computation + mel_loss = criterion(mel_output, mel_spec_var) + linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec_var[: ,: ,:n_priority_freq]) + loss = mel_loss + linear_loss + + # backpass and check the grad norm + loss.backward() + grad_norm, skip_flag = check_update(model, 0.5, 100) + if skip_flag: + optimizer.zero_grad() + print(" | > Iteration skipped!!") + continue + optimizer.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # update + progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), + ('linear_loss', linear_loss.data[0]), + ('mel_loss', mel_loss.data[0]), + ('grad_norm', grad_norm)]) + + # Plot Training Iter Stats + tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step) + tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0], + current_step) + tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], + current_step) + tb.add_scalar('Params/GradNorm', grad_norm, current_step) + tb.add_scalar('Time/StepTime', step_time, current_step) + + if current_step % c.save_step == 0: + if c.checkpoint: + # save model + save_checkpoint(model, optimizer, linear_loss.data[0], + OUT_PATH, current_step, epoch) + + # Diagnostic visualizations + const_spec = linear_output[0].data.cpu().numpy() + gt_spec = linear_spec_var[0].data.cpu().numpy() + + const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) + gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) + tb.add_image('Visual/Reconstruction', const_spec, current_step) + tb.add_image('Visual/GroundTruth', gt_spec, current_step) + + align_img = alignments[0].data.cpu().numpy() + align_img = plot_alignment(align_img) + tb.add_image('Visual/Alignment', align_img, current_step) + + # Sample audio + audio_signal = linear_output[0].data.cpu().numpy() + data_loader.dataset.ap.griffin_lim_iters = 60 + audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T) + try: + tb.add_audio('SampleAudio', audio_signal, current_step, + sample_rate=c.sample_rate) + except: + print("\n > Error at audio signal on TB!!") + print(audio_signal.max()) + print(audio_signal.min()) + + + avg_linear_loss /= (num_iter + 1) + avg_mel_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + + # Plot Training Epoch Stats + tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step) + tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step) + tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('Time/EpochTime', epoch_time, epoch) + epoch_time = 0 + + return avg_linear_loss, current_step + + +def evaluate(model, criterion, data_loader, current_step): + model = model.train() + epoch_time = 0 + + print(" | > Validation") + n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + progbar = Progbar(len(data_loader.dataset) / c.batch_size) + + avg_linear_loss = 0 + avg_mel_loss = 0 + + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + + # convert inputs to variables + text_input_var = Variable(text_input) + mel_spec_var = Variable(mel_input) + linear_spec_var = Variable(linear_input, volatile=True) + + # dispatch data to GPU + if use_cuda: + text_input_var = text_input_var.cuda() + mel_spec_var = mel_spec_var.cuda() + linear_spec_var = linear_spec_var.cuda() + + # forward pass + mel_output, linear_output, alignments =\ + model.forward(text_input_var, mel_spec_var) + + # loss computation + mel_loss = criterion(mel_output, mel_spec_var) + linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec_var[: ,: ,:n_priority_freq]) + loss = mel_loss + linear_loss + + step_time = time.time() - start_time + epoch_time += step_time + + # update + progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), + ('linear_loss', linear_loss.data[0]), + ('mel_loss', mel_loss.data[0])]) + + avg_linear_loss += linear_loss.data[0] + avg_mel_loss += mel_loss.data[0] + + # Diagnostic visualizations + idx = np.random.randint(mel_input.shape[0]) + const_spec = linear_output[idx].data.cpu().numpy() + gt_spec = linear_spec_var[idx].data.cpu().numpy() + align_img = alignments[idx].data.cpu().numpy() + + const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) + gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) + align_img = plot_alignment(align_img) + + tb.add_image('ValVisual/Reconstruction', const_spec, current_step) + tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) + tb.add_image('ValVisual/ValidationAlignment', align_img, current_step) + + # Sample audio + audio_signal = linear_output[idx].data.cpu().numpy() + data_loader.dataset.ap.griffin_lim_iters = 60 + audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T) + try: + tb.add_audio('ValSampleAudio', audio_signal, current_step, + sample_rate=c.sample_rate) + except: + print(" | > Error at audio signal on TB!!") + print(audio_signal.max()) + print(audio_signal.min()) + + # compute average losses + avg_linear_loss /= (num_iter + 1) + avg_mel_loss /= (num_iter + 1) + avg_total_loss = avg_mel_loss + avg_linear_loss + + # Plot Learning Stats + tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) + tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) + tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) + return avg_linear_loss + + def main(args): - # setup output paths and read configs - c = load_config(args.config_path) - _ = os.path.dirname(os.path.realpath(__file__)) - OUT_PATH = os.path.join(_, c.output_path) - OUT_PATH = create_experiment_folder(OUT_PATH) - CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') - shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) - - # save config to tmp place to be loaded by subsequent modules. - file_name = str(os.getpid()) - tmp_path = os.path.join("/tmp/", file_name+'_tts') - pickle.dump(c, open(tmp_path, "wb")) - - # setup tensorboard - LOG_DIR = OUT_PATH - tb = SummaryWriter(LOG_DIR) - - # Ctrl+C handler to remove empty experiment folder - def signal_handler(signal, frame): - print(" !! Pressed Ctrl+C !!") - remove_experiment_folder(OUT_PATH) - sys.exit(1) - signal.signal(signal.SIGINT, signal_handler) - # Setup the dataset - dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), + train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), os.path.join(c.data_path, 'wavs'), c.r, c.sample_rate, @@ -71,204 +304,77 @@ def main(args): c.power ) - dataloader = DataLoader(dataset, batch_size=c.batch_size, - shuffle=True, collate_fn=dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers) + train_loader = DataLoader(train_dataset, batch_size=c.batch_size, + shuffle=False, collate_fn=train_dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers, + pin_memory=True) + + val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), + os.path.join(c.data_path, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) + + val_loader = DataLoader(val_dataset, batch_size=c.batch_size, + shuffle=False, collate_fn=val_dataset.collate_fn, + drop_last=False, num_workers= 4, + pin_memory=True) - # setup the model model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq, - c.r) - - # plot model on tensorboard - dummy_input = dataset.get_dummy_data() - - ## TODO: onnx does not support RNN fully yet - # model_proto_path = os.path.join(OUT_PATH, "model.proto") - # onnx.export(model, dummy_input, model_proto_path, verbose=True) - # tb.add_graph_onnx(model_proto_path) - - if use_cuda: - model = nn.DataParallel(model.cuda()) + c.r, + use_atten_mask=True) optimizer = optim.Adam(model.parameters(), lr=c.lr) - - if args.restore_step: - checkpoint = torch.load(os.path.join( - args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step)) - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("\n > Model restored from step %d\n" % args.restore_step) - start_epoch = checkpoint['step'] // len(dataloader) - best_loss = checkpoint['linear_loss'] - else: - start_epoch = 0 - print("\n > Starting a new training") - - num_params = count_parameters(model) - print(" | > Model has {} parameters".format(num_params)) - - model = model.train() - - if not os.path.exists(CHECKPOINT_PATH): - os.mkdir(CHECKPOINT_PATH) - + if use_cuda: criterion = nn.L1Loss().cuda() else: criterion = nn.L1Loss() - n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + if args.restore_path: + checkpoint = torch.load(args.restore_path) + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("\n > Model restored from step %d\n" % checkpoint['step']) + start_epoch = checkpoint['step'] // len(train_loader) + best_loss = checkpoint['linear_loss'] + start_epoch = 0 + args.restore_step = checkpoint['step'] + else: + args.restore_step = 0 + print("\n > Starting a new training") - #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay, - # patience=c.lr_patience, verbose=True) - epoch_time = 0 - best_loss = float('inf') + if use_cuda: + model = nn.DataParallel(model.cuda()) + + num_params = count_parameters(model) + print(" | > Model has {} parameters".format(num_params)) + + if not os.path.exists(CHECKPOINT_PATH): + os.mkdir(CHECKPOINT_PATH) + + if 'best_loss' not in locals(): + best_loss = float('inf') + for epoch in range(0, c.epochs): - - print("\n | > Epoch {}/{}".format(epoch, c.epochs)) - progbar = Progbar(len(dataset) / c.batch_size) - - for num_iter, data in enumerate(dataloader): - start_time = time.time() - - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - - current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1 - - # setup lr - current_lr = lr_decay(c.lr, current_step) - for params_group in optimizer.param_groups: - params_group['lr'] = current_lr - - optimizer.zero_grad() - - # Add a single frame of zeros to Mel Specs for better end detection - #try: - # mel_input = np.concatenate((np.zeros( - # [c.batch_size, 1, c.num_mels], dtype=np.float32), - # mel_input[:, 1:, :]), axis=1) - #except: - # raise TypeError("not same dimension") - - # convert inputs to variables - text_input_var = Variable(text_input) - mel_spec_var = Variable(mel_input) - linear_spec_var = Variable(linear_input, volatile=True) - - # sort sequence by length. - # TODO: might be unnecessary - sorted_lengths, indices = torch.sort( - text_lengths.view(-1), dim=0, descending=True) - sorted_lengths = sorted_lengths.long().numpy() - - text_input_var = text_input_var[indices] - mel_spec_var = mel_spec_var[indices] - linear_spec_var = linear_spec_var[indices] - - if use_cuda: - text_input_var = text_input_var.cuda() - mel_spec_var = mel_spec_var.cuda() - linear_spec_var = linear_spec_var.cuda() - - mel_output, linear_output, alignments =\ - model.forward(text_input_var, mel_spec_var, - input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths))) - - mel_loss = criterion(mel_output, mel_spec_var) - #linear_loss = torch.abs(linear_output - linear_spec_var) - #linear_loss = 0.5 * \ - #torch.mean(linear_loss) + 0.5 * \ - #torch.mean(linear_loss[:, :n_priority_freq, :]) - linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \ - + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec_var[: ,: ,:n_priority_freq]) - loss = mel_loss + linear_loss - # loss = loss.cuda() - - loss.backward() - grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need - optimizer.step() - - step_time = time.time() - start_time - epoch_time += step_time - - progbar.update(num_iter+1, values=[('total_loss', loss.data[0]), - ('linear_loss', linear_loss.data[0]), - ('mel_loss', mel_loss.data[0]), - ('grad_norm', grad_norm)]) - - # Plot Learning Stats - tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step) - tb.add_scalar('Loss/LinearLoss', linear_loss.data[0], - current_step) - tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step) - tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'], - current_step) - tb.add_scalar('Params/GradNorm', grad_norm, current_step) - tb.add_scalar('Time/StepTime', step_time, current_step) - - align_img = alignments[0].data.cpu().numpy() - align_img = plot_alignment(align_img) - tb.add_image('Attn/Alignment', align_img, current_step) - - if current_step % c.save_step == 0: - - if c.checkpoint: - # save model - save_checkpoint(model, optimizer, linear_loss.data[0], - OUT_PATH, current_step, epoch) - - # Diagnostic visualizations - const_spec = linear_output[0].data.cpu().numpy() - gt_spec = linear_spec_var[0].data.cpu().numpy() - - const_spec = plot_spectrogram(const_spec, dataset.ap) - gt_spec = plot_spectrogram(gt_spec, dataset.ap) - tb.add_image('Spec/Reconstruction', const_spec, current_step) - tb.add_image('Spec/GroundTruth', gt_spec, current_step) - - align_img = alignments[0].data.cpu().numpy() - align_img = plot_alignment(align_img) - tb.add_image('Attn/Alignment', align_img, current_step) - - # Sample audio - audio_signal = linear_output[0].data.cpu().numpy() - dataset.ap.griffin_lim_iters = 60 - audio_signal = dataset.ap.inv_spectrogram(audio_signal.T) - try: - tb.add_audio('SampleAudio', audio_signal, current_step, - sample_rate=c.sample_rate) - except: - print("\n > Error at audio signal on TB!!") - print(audio_signal.max()) - print(audio_signal.min()) - - - # average loss after the epoch - avg_epoch_loss = np.mean( - progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1])) - best_loss = save_best_model(model, optimizer, avg_epoch_loss, + train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch) + val_loss = evaluate(model, criterion, val_loader, current_step) + best_loss = save_best_model(model, optimizer, val_loss, best_loss, OUT_PATH, current_step, epoch) - #lr_scheduler.step(loss.data[0]) - tb.add_scalar('Time/EpochTime', epoch_time, epoch) - epoch_time = 0 - - if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--restore_step', type=int, - help='Global step to restore checkpoint', default=0) - parser.add_argument('--restore_path', type=str, - help='Folder path to checkpoints', default=0) - parser.add_argument('--config_path', type=str, - help='path to config file for training',) - args = parser.parse_args() + signal.signal(signal.SIGINT, signal_handler) main(args) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index ca32060c..4832ec44 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -7,6 +7,7 @@ import datetime import json import torch import numpy as np +from collections import OrderedDict class AttrDict(dict): @@ -94,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, return best_loss -def lr_decay(init_lr, global_step): - warmup_steps = 4000.0 +def check_update(model, grad_clip, grad_top): + r'''Check model gradient against unexpected jumps and failures''' + skip_flag = False + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip) + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + elif grad_norm > grad_top: + print(" | > Gradient is above the top limit !!") + skip_flag = True + return grad_norm, skip_flag + + +def lr_decay(init_lr, global_step, warmup_steps): + r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py''' step = global_step + 1. lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) @@ -197,13 +211,13 @@ class Progbar(object): eta_format = '%ds' % eta info = ' - ETA: %s' % eta_format + + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) else: - if time_per_unit >= 1: - info += ' %.0fs/step' % time_per_unit - elif time_per_unit >= 1e-3: - info += ' %.0fms/step' % (time_per_unit * 1e3) - else: - info += ' %.0fus/step' % (time_per_unit * 1e6) + info += ' %.0fus/step' % (time_per_unit * 1e6) for k in self.unique_values: info += ' - %s:' % k diff --git a/utils/visual.py b/utils/visual.py index 935ed369..b0143fc9 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt def plot_alignment(alignment, info=None): - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=(16,10)) im = ax.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none') fig.colorbar(im, ax=ax)