diff --git a/.gitignore b/.gitignore index b6fee485..579bfbea 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ TODO.txt .vscode/* data/* notebooks/data/* +TTS/tts/layers/glow_tts/monotonic_align/core.c diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py new file mode 100644 index 00000000..cf9d98d2 --- /dev/null +++ b/TTS/bin/train_glow_tts.py @@ -0,0 +1,649 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import os +import sys +import time +import traceback + +import numpy as np +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from TTS.tts.datasets.preprocess import load_meta_data +from TTS.tts.datasets.TTSDataset import MyDataset +from TTS.tts.layers.losses import GlowTTSLoss +from TTS.tts.utils.distribute import (DistributedSampler, init_distributed, + reduce_tensor) +from TTS.tts.utils.generic_utils import check_config, setup_model +from TTS.tts.utils.io import save_best_model, save_checkpoint +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, + save_speaker_mapping) +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.generic_utils import (KeepAverage, count_parameters, + create_experiment_folder, get_git_branch, + remove_experiment_folder, set_init_dict) +from TTS.utils.io import copy_config_file, load_config +from TTS.utils.radam import RAdam +from TTS.utils.tensorboard_logger import TensorboardLogger +from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, + gradual_training_scheduler, set_weight_decay, + setup_torch_training_env) + +use_cuda, num_gpus = setup_torch_training_env(True, False) + + +def setup_loader(ap, r, is_val=False, verbose=False): + if is_val and not c.run_eval: + loader = None + else: + dataset = MyDataset( + r, + c.text_cleaner, + compute_linear_spec=True if c.model.lower() == 'tacotron' else False, + meta_data=meta_data_eval if is_val else meta_data_train, + ap=ap, + tp=c.characters if 'characters' in c.keys() else None, + batch_group_size=0 if is_val else c.batch_group_size * + c.batch_size, + min_seq_len=c.min_seq_len, + max_seq_len=c.max_seq_len, + phoneme_cache_path=c.phoneme_cache_path, + use_phonemes=c.use_phonemes, + phoneme_language=c.phoneme_language, + enable_eos_bos=c.enable_eos_bos_chars, + verbose=verbose) + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=c.eval_batch_size if is_val else c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers + if is_val else c.num_loader_workers, + pin_memory=False) + return loader + + +def format_data(data): + if c.use_speaker_embedding: + speaker_mapping = load_speaker_mapping(OUT_PATH) + + # setup input data + text_input = data[0] + text_lengths = data[1] + speaker_names = data[2] + mel_input = data[4].permute(0, 2, 1) # B x D x T + mel_lengths = data[5] + attn_mask = data[8] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + if c.use_speaker_embedding: + speaker_ids = [ + speaker_mapping[speaker_name] for speaker_name in speaker_names + ] 
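        # Editor's note (descriptive comment, not part of the original patch):
        # the numeric indices unpacked above refer to positions in the batch
        # tuple built by MyDataset.collate_fn -- text ids (0), text lengths (1),
        # speaker names (2), mel frames (4), mel lengths (5) -- and data[8] is
        # read as the pre-computed attention mask (None unless the dataset
        # provides alignment files).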
+ speaker_ids = torch.LongTensor(speaker_ids) + else: + speaker_ids = None + + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_ids is not None: + speaker_ids = speaker_ids.cuda(non_blocking=True) + if attn_mask is not None: + attn_mask = attn_mask.cuda(non_blocking=True) + return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ + avg_text_length, avg_spec_length, attn_mask + + +def data_depended_init(model, ap): + """Data depended initialization for activation normalization.""" + if hasattr(model, 'module'): + for f in model.module.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + else: + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + data_loader = setup_loader(ap, 1, is_val=False) + model.train() + print(" > Data depended initialization ... ") + with torch.no_grad(): + for num_iter, data in enumerate(data_loader): + + # format data + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ + avg_text_length, avg_spec_length, attn_mask = format_data(data) + + # forward pass model + _ = model.forward( + text_input, text_lengths, mel_input, mel_lengths, attn_mask) + break + + if hasattr(model, 'module'): + for f in model.module.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + else: + for f in model.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + return model + + +def train(model, criterion, optimizer, scheduler, + ap, global_step, epoch, amp): + data_loader = setup_loader(ap, 1, is_val=False, + verbose=(epoch == 0)) + model.train() + epoch_time = 0 + keep_avg = KeepAverage() + if use_cuda: + batch_n_iter = int( + len(data_loader.dataset) / (c.batch_size * num_gpus)) + else: + batch_n_iter = int(len(data_loader.dataset) / c.batch_size) + end_time = time.time() + c_logger.print_train_start() + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # format data + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ + avg_text_length, avg_spec_length, attn_mask = format_data(data) + + loader_time = time.time() - end_time + + global_step += 1 + + # setup lr + if c.noam_schedule: + scheduler.step() + optimizer.zero_grad() + + # forward pass model + z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( + text_input, text_lengths, mel_input, mel_lengths, attn_mask) + + # compute loss + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, + o_dur_log, o_total_dur, text_lengths) + + # backward pass + if amp is not None: + with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss_dict['loss'].backward() + + if amp: + amp_opt_params = amp.master_params(optimizer) + else: + amp_opt_params = None + grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params) + optimizer.step() + + # current_lr + current_lr = optimizer.param_groups[0]['lr'] + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(alignments) + loss_dict['align_error'] = align_error + + step_time = time.time() - start_time + epoch_time += step_time + + # aggregate losses from processes + if num_gpus > 1: + loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) + loss_dict['loss_dur'] = 
reduce_tensor(loss_dict['loss_dur'].data, num_gpus) + loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + + # detach loss values + loss_dict_new = dict() + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_new[key] = value + else: + loss_dict_new[key] = value.item() + loss_dict = loss_dict_new + + # update avg stats + update_train_values = dict() + for key, value in loss_dict.items(): + update_train_values['avg_' + key] = value + update_train_values['avg_loader_time'] = loader_time + update_train_values['avg_step_time'] = step_time + keep_avg.update_values(update_train_values) + + # print training progress + if global_step % c.print_step == 0: + log_dict = { + "avg_spec_length": [avg_spec_length, 1], # value, precision + "avg_text_length": [avg_text_length, 1], + "step_time": [step_time, 4], + "loader_time": [loader_time, 2], + "current_lr": current_lr, + } + c_logger.print_train_step(batch_n_iter, num_iter, global_step, + log_dict, loss_dict, keep_avg.avg_values) + + if args.rank == 0: + # Plot Training Iter Stats + # reduce TB load + if global_step % c.tb_plot_step == 0: + iter_stats = { + "lr": current_lr, + "grad_norm": grad_norm, + "step_time": step_time + } + iter_stats.update(loss_dict) + tb_logger.tb_train_iter_stats(global_step, iter_stats) + + if global_step % c.save_step == 0: + if c.checkpoint: + # save model + save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, + model_loss=loss_dict['loss'], + amp_state_dict=amp.state_dict() if amp else None) + + # Diagnostic visualizations + # direct pass on model for spec predictions + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred = spec_pred.permute(0, 2, 1) + gt_spec = mel_input.permute(0, 2, 1) + const_spec = spec_pred[0].data.cpu().numpy() + gt_spec = gt_spec[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(const_spec, ap), + "ground_truth": plot_spectrogram(gt_spec, ap), + "alignment": plot_alignment(align_img), + } + + tb_logger.tb_train_figures(global_step, figures) + + # Sample audio + train_audio = ap.inv_melspectrogram(const_spec.T) + tb_logger.tb_train_audios(global_step, + {'TrainAudio': train_audio}, + c.audio["sample_rate"]) + end_time = time.time() + + # print epoch stats + c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) + + # Plot Epoch Stats + if args.rank == 0: + epoch_stats = {"epoch_time": epoch_time} + epoch_stats.update(keep_avg.avg_values) + tb_logger.tb_train_epoch_stats(global_step, epoch_stats) + if c.tb_model_param_stats: + tb_logger.tb_model_weights(model, global_step) + return keep_avg.avg_values, global_step + + +@torch.no_grad() +def evaluate(model, criterion, ap, global_step, epoch): + data_loader = setup_loader(ap, 1, is_val=True) + model.eval() + epoch_time = 0 + keep_avg = KeepAverage() + c_logger.print_eval_start() + if data_loader is not None: + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # format data + text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\ + avg_text_length, avg_spec_length, attn_mask = format_data(data) + + # forward pass model + z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( + text_input, text_lengths, mel_input, mel_lengths, attn_mask) + + # compute loss + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, + o_dur_log, o_total_dur, text_lengths) + + # step time + step_time = time.time() - 
start_time + epoch_time += step_time + + # compute alignment score + align_error = 1 - alignment_diagonal_score(alignments) + loss_dict['align_error'] = align_error + + # aggregate losses from processes + if num_gpus > 1: + loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) + loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) + loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + + # detach loss values + loss_dict_new = dict() + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_new[key] = value + else: + loss_dict_new[key] = value.item() + loss_dict = loss_dict_new + + # update avg stats + update_train_values = dict() + for key, value in loss_dict.items(): + update_train_values['avg_' + key] = value + keep_avg.update_values(update_train_values) + + if c.print_eval: + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) + + if args.rank == 0: + # Diagnostic visualizations + # direct pass on model for spec predictions + if hasattr(model, 'module'): + spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1]) + else: + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1]) + spec_pred = spec_pred.permute(0, 2, 1) + gt_spec = mel_input.permute(0, 2, 1) + + const_spec = spec_pred[0].data.cpu().numpy() + gt_spec = gt_spec[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + eval_figures = { + "prediction": plot_spectrogram(const_spec, ap), + "ground_truth": plot_spectrogram(gt_spec, ap), + "alignment": plot_alignment(align_img) + } + + # Sample audio + eval_audio = ap.inv_melspectrogram(const_spec.T) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, + c.audio["sample_rate"]) + + # Plot Validation Stats + tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) + tb_logger.tb_eval_figures(global_step, eval_figures) + + if args.rank == 0 and epoch >= c.test_delay_epochs: + if c.test_sentences_file is None: + test_sentences = [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963." + ] + else: + with open(c.test_sentences_file, "r") as f: + test_sentences = [s.strip() for s in f.readlines()] + + # test sentences + test_audios = {} + test_figures = {} + print(" | > Synthesizing test sentences") + speaker_id = 0 if c.use_speaker_embedding else None + style_wav = c.get("style_wav_for_test") + for idx, test_sentence in enumerate(test_sentences): + try: + wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis( + model, + test_sentence, + c, + use_cuda, + ap, + speaker_id=speaker_id, + style_wav=style_wav, + truncated=False, + enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + use_griffin_lim=True, + do_trim_silence=False) + + file_path = os.path.join(AUDIO_PATH, str(global_step)) + os.makedirs(file_path, exist_ok=True) + file_path = os.path.join(file_path, + "TestSentence_{}.wav".format(idx)) + ap.save_wav(wav, file_path) + test_audios['{}-audio'.format(idx)] = wav + test_figures['{}-prediction'.format(idx)] = plot_spectrogram( + postnet_output, ap) + test_figures['{}-alignment'.format(idx)] = plot_alignment( + alignment) + except: + print(" !! 
Error creating Test Sentence -", idx) + traceback.print_exc() + tb_logger.tb_test_audios(global_step, test_audios, + c.audio['sample_rate']) + tb_logger.tb_test_figures(global_step, test_figures) + return keep_avg.avg_values + + +# FIXME: move args definition/parsing inside of main? +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train, meta_data_eval, symbols, phonemes + # Audio processor + ap = AudioProcessor(**c.audio) + if 'characters' in c.keys(): + symbols, phonemes = make_symbols(**c.characters) + + # DISTRUBUTED + if num_gpus > 1: + init_distributed(args.rank, num_gpus, args.group_id, + c.distributed["backend"], c.distributed["url"]) + num_chars = len(phonemes) if c.use_phonemes else len(symbols) + + # load data instances + meta_data_train, meta_data_eval = load_meta_data(c.datasets) + + # set the portion of the data used for training + if 'train_portion' in c.keys(): + meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] + if 'eval_portion' in c.keys(): + meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + + # parse speakers + if c.use_speaker_embedding: + speakers = get_speakers(meta_data_train) + if args.restore_path: + prev_out_path = os.path.dirname(args.restore_path) + speaker_mapping = load_speaker_mapping(prev_out_path) + assert all([speaker in speaker_mapping + for speaker in speakers]), "As of now you, you cannot " \ + "introduce new speakers to " \ + "a previously trained model." + else: + speaker_mapping = {name: i for i, name in enumerate(speakers)} + save_speaker_mapping(OUT_PATH, speaker_mapping) + num_speakers = len(speaker_mapping) + print("Training with {} speakers: {}".format(num_speakers, + ", ".join(speakers))) + else: + num_speakers = 0 + + # setup model + model = setup_model(num_chars, num_speakers, c) + optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) + criterion = GlowTTSLoss() + + if c.apex_amp_level: + # pylint: disable=import-outside-toplevel + from apex import amp + from apex.parallel import DistributedDataParallel as DDP + model.cuda() + model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level) + else: + amp = None + + if args.restore_path: + checkpoint = torch.load(args.restore_path, map_location='cpu') + try: + # TODO: fix optimizer init, model.cuda() needs to be called before + # optimizer restore + optimizer.load_state_dict(checkpoint['optimizer']) + if c.reinit_layers: + raise RuntimeError + model.load_state_dict(checkpoint['model']) + except: + print(" > Partial model initialization.") + model_dict = model.state_dict() + model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model.load_state_dict(model_dict) + del model_dict + + if amp and 'amp' in checkpoint: + amp.load_state_dict(checkpoint['amp']) + + for group in optimizer.param_groups: + group['initial_lr'] = c.lr + print(" > Model restored from step %d" % checkpoint['step'], + flush=True) + args.restore_step = checkpoint['step'] + else: + args.restore_step = 0 + + if use_cuda: + model.cuda() + criterion.cuda() + + # DISTRUBUTED + if num_gpus > 1: + model = DDP(model) + + if c.noam_schedule: + scheduler = NoamLR(optimizer, + warmup_steps=c.warmup_steps, + last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if 'best_loss' not in locals(): + best_loss = float('inf') + + 
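    # Editor's sketch (assumption, not part of the original patch): when
    # c.noam_schedule is enabled, the NoamLR scheduler built above is assumed
    # to follow the standard Noam warm-up rule -- a linear ramp for
    # `warmup_steps` steps followed by inverse-square-root decay -- roughly:
    #
    #   def noam_lr(step, base_lr=c.lr, warmup_steps=c.warmup_steps):
    #       step = max(step, 1)
    #       return base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5,
    #                                                step**-0.5)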
global_step = args.restore_step + model = data_depended_init(model, ap) + for epoch in range(0, c.epochs): + c_logger.print_epoch_start(epoch, c.epochs) + train_avg_loss_dict, global_step = train(model, criterion, optimizer, + scheduler, ap, global_step, + epoch, amp) + eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) + c_logger.print_epoch_end(epoch, eval_avg_loss_dict) + target_loss = train_avg_loss_dict['avg_loss'] + if c.run_eval: + target_loss = eval_avg_loss_dict['avg_loss'] + best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, + OUT_PATH, amp_state_dict=amp.state_dict() if amp else None) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--continue_path', + type=str, + help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', + default='', + required='--config_path' not in sys.argv) + parser.add_argument( + '--restore_path', + type=str, + help='Model file to be restored. Use to finetune a model.', + default='') + parser.add_argument( + '--config_path', + type=str, + help='Path to config file for training.', + required='--continue_path' not in sys.argv + ) + parser.add_argument('--debug', + type=bool, + default=False, + help='Do not verify commit integrity to run training.') + + # DISTRUBUTED + parser.add_argument( + '--rank', + type=int, + default=0, + help='DISTRIBUTED: process rank for distributed training.') + parser.add_argument('--group_id', + type=str, + default="", + help='DISTRIBUTED: process group id.') + args = parser.parse_args() + + if args.continue_path != '': + args.output_path = args.continue_path + args.config_path = os.path.join(args.continue_path, 'config.json') + list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv + latest_model_file = max(list_of_files, key=os.path.getctime) + args.restore_path = latest_model_file + print(f" > Training continues for {args.restore_path}") + + # setup output paths and read configs + c = load_config(args.config_path) + # check_config(c) + _ = os.path.dirname(os.path.realpath(__file__)) + + if c.apex_amp_level: + print(" > apex AMP level: ", c.apex_amp_level) + + OUT_PATH = args.continue_path + if args.continue_path == '': + OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) + + AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') + + c_logger = ConsoleLogger() + + if args.rank == 0: + os.makedirs(AUDIO_PATH, exist_ok=True) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + copy_config_file(args.config_path, + os.path.join(OUT_PATH, 'config.json'), new_fields) + os.chmod(AUDIO_PATH, 0o775) + os.chmod(OUT_PATH, 0o775) + + LOG_DIR = OUT_PATH + tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS') + + # write model desc to tensorboard + tb_logger.tb_add_text('model-description', c['run_description'], 0) + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/TTS/tts/configs/glow_tts_gated_conv.json b/TTS/tts/configs/glow_tts_gated_conv.json new file mode 100644 index 00000000..696bdaf7 --- /dev/null +++ 
b/TTS/tts/configs/glow_tts_gated_conv.json @@ -0,0 +1,132 @@ +{ + "model": "glow_tts", + "run_name": "glow-tts-gatedconv", + "run_description": "glow-tts model training with gated conv.", + + // AUDIO PARAMETERS + "audio":{ + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Griffin-Lim + "power": 1.1, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // MODEL PARAMETERS + "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. + + // TRAINING + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. 
+ "eval_batch_size":16, + "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": true, // use noam warmup and lr schedule. + "grad_clip": 5.0, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate. + "wd": 0.000001, // Weight decay weight. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. + + "encoder_type": "gatedconv", + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "apex_amp_level": null, + + // DATA LOADING + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 500, // DATASET-RELATED: maximum text length + "compute_f0": false, // compute f0 values in data-loader + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/", + + // PHONEMES + "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. + "use_gst": false, // TACOTRON ONLY: use global style tokens + + // DATASETS + "datasets": // List of datasets. They all merged and they get different speaker_ids. 
+ [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null + // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/" + } + ] +} + + diff --git a/TTS/tts/configs/glow_tts_tdsep.json b/TTS/tts/configs/glow_tts_tdsep.json new file mode 100644 index 00000000..67047523 --- /dev/null +++ b/TTS/tts/configs/glow_tts_tdsep.json @@ -0,0 +1,132 @@ +{ + "model": "glow_tts", + "run_name": "glow-tts-tdsep-conv", + "run_description": "glow-tts model training with time-depth separable conv encoder.", + + // AUDIO PARAMETERS + "audio":{ + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Griffin-Lim + "power": 1.1, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. 
If not defined, it reloads all heuristically matching layers. + + // MODEL PARAMETERS + "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. + + // TRAINING + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size":16, + "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": true, // use noam warmup and lr schedule. + "grad_clip": 5.0, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate. + "wd": 0.000001, // Weight decay weight. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. + + "encoder_type": "time-depth-separable", + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "apex_amp_level": null, + + // DATA LOADING + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 500, // DATASET-RELATED: maximum text length + "compute_f0": false, // compute f0 values in data-loader + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/", + + // PHONEMES + "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. + "use_gst": false, // TACOTRON ONLY: use global style tokens + + // DATASETS + "datasets": // List of datasets. 
They all merged and they get different speaker_ids. + [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null + // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/" + } + ] + } + + diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 9c50cb6a..a92b880f 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -113,7 +113,14 @@ class MyDataset(Dataset): return phonemes def load_data(self, idx): - text, wav_file, speaker_name = self.items[idx] + item = self.items[idx] + + if len(item) == 4: + text, wav_file, speaker_name, attn_file = item + else: + text, wav_file, speaker_name = item + attn = None + wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) if self.use_phonemes: @@ -125,9 +132,13 @@ class MyDataset(Dataset): assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] + if "attn_file" in locals(): + attn = np.load(attn_file) + sample = { 'text': text, 'wav': wav, + 'attn': attn, 'item_idx': self.items[idx][1], 'speaker_name': speaker_name, 'wav_file_name': os.path.basename(wav_file) @@ -245,8 +256,21 @@ class MyDataset(Dataset): linear = torch.FloatTensor(linear).contiguous() else: linear = None + + # collate attention alignments + if batch[0]['attn'] is not None: + attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing] + for idx, attn in enumerate(attns): + pad2 = mel.shape[1] - attn.shape[1] + pad1 = text.shape[1] - attn.shape[0] + attn = np.pad(attn, [[0, pad1], [0, pad2]]) + attns[idx] = attn + attns = prepare_tensor(attns, self.outputs_per_step) + attns = torch.FloatTensor(attns).unsqueeze(1) + else: + attns = None return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ - stop_targets, item_idxs, speaker_embedding + stop_targets, item_idxs, speaker_embedding, attns raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}".format(type(batch[0])))) diff --git a/TTS/tts/layers/glow_tts/__init__.py b/TTS/tts/layers/glow_tts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py new file mode 100644 index 00000000..43811821 --- /dev/null +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -0,0 +1,111 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock +from TTS.tts.layers.glow_tts.normalization import ActNorm + + +def squeeze(x, x_mask=None, num_sqz=2): + b, c, t = x.size() + + t = (t // num_sqz) * num_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // num_sqz, num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, + 2).contiguous().view(b, c * num_sqz, t // num_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, num_sqz - 1::num_sqz] + else: + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, + dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, num_sqz=2): + b, c, t = x.size() + + x_unsqz = x.view(b, num_sqz, c // num_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, + 1).contiguous().view(b, c // num_sqz, + t * num_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, + num_sqz).view(b, 1, t * num_sqz) + else: + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, + dtype=x.dtype) + return x_unsqz * x_mask, x_mask + + +class Decoder(nn.Module): + """Stack of Glow Modules""" + def 
__init__(self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0., + num_splits=4, + num_sqz=2, + sigmoid_scale=False, + c_in_channels=0, + feat_channels=None): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_flow_blocks = num_flow_blocks + self.num_coupling_layers = num_coupling_layers + self.dropout_p = dropout_p + self.num_splits = num_splits + self.num_sqz = num_sqz + self.sigmoid_scale = sigmoid_scale + self.c_in_channels = c_in_channels + + self.flows = nn.ModuleList() + for _ in range(num_flow_blocks): + self.flows.append(ActNorm(channels=in_channels * num_sqz)) + self.flows.append( + InvConvNear(channels=in_channels * num_sqz, + num_splits=num_splits)) + self.flows.append( + CouplingBlock(in_channels * num_sqz, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale)) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + flows = self.flows + logdet_tot = 0 + else: + flows = reversed(self.flows) + logdet_tot = None + + if self.num_sqz > 1: + x, x_mask = squeeze(x, x_mask, self.num_sqz) + for f in flows: + if not reverse: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + logdet_tot += logdet + else: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + if self.num_sqz > 1: + x, x_mask = unsqueeze(x, x_mask, self.num_sqz) + return x, logdet_tot + + def store_inverse(self): + for f in self.flows: + f.store_inverse() diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py new file mode 100644 index 00000000..9f825832 --- /dev/null +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + +from .normalization import LayerNorm + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, dropout_p): + super().__init__() + # class arguments + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + # layers + self.drop = nn.Dropout(dropout_p) + self.conv_1 = nn.Conv1d(in_channels, + filter_channels, + kernel_size, + padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, + filter_channels, + kernel_size, + padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + # output layer + self.proj = nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py new file mode 100644 index 00000000..df0e0462 --- /dev/null +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -0,0 +1,145 @@ +import math +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.transformer import Transformer +from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor 
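# Editor's usage sketch (not part of the original patch): the Decoder defined
# earlier in this diff is a normalizing flow, so the same module is run
# forwards during training and in reverse at inference:
#
#   z, logdet = decoder(mel, mel_mask)               # mel -> latent, training
#   decoder.store_inverse()                          # cache 1x1-conv inverses
#   mel_hat, _ = decoder(z, mel_mask, reverse=True)  # latent -> mel, inference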
+from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock + + +class Encoder(nn.Module): + """Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding + as in the original paper or GatedConvBlock as a faster alternative. + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + hidden_channels (int): encoder's embedding size. + filter_channels (int): transformer's feed-forward channels. + num_head (int): number of attention heads in transformer. + num_layers (int): number of transformer encoder stack. + kernel_size (int): kernel size for conv layers and duration predictor. + dropout_p (float): dropout rate for any dropout layer. + mean_only (bool): if True, output only mean values and use constant std. + use_prenet (bool): if True, use pre-convolutional layers before transformer layers. + c_in_channels (int): number of channels in conditional input. + + Shapes: + - input: (B, T, C) + """ + def __init__(self, + num_chars, + out_channels, + hidden_channels, + filter_channels, + filter_channels_dp, + encoder_type, + num_heads, + num_layers, + kernel_size, + dropout_p, + rel_attn_window_size=None, + input_length=None, + mean_only=False, + use_prenet=True, + c_in_channels=0): + super().__init__() + # class arguments + self.num_chars = num_chars + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.filter_channels_dp = filter_channels_dp + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.mean_only = mean_only + self.use_prenet = use_prenet + self.c_in_channels = c_in_channels + self.encoder_type = encoder_type + # embedding layer + self.emb = nn.Embedding(num_chars, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + # init encoder + if encoder_type.lower() == "transformer": + # optional convolutional prenet + if use_prenet: + self.pre = ConvLayerNorm(hidden_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_layers=3, + dropout_p=0.5) + # text encoder + self.encoder = Transformer( + hidden_channels, + filter_channels, + num_heads, + num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + rel_attn_window_size=rel_attn_window_size, + input_length=input_length) + elif encoder_type.lower() == 'gatedconv': + self.encoder = GatedConvBlock(hidden_channels, + kernel_size=5, + dropout_p=dropout_p, + num_layers=3 + num_layers) + elif encoder_type.lower() == 'time-depth-separable': + # optional convolutional prenet + if use_prenet: + self.pre = ConvLayerNorm(hidden_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_layers=3, + dropout_p=0.5) + self.encoder = TimeDepthSeparableConvBlock(hidden_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_layers=3 + num_layers) + + # final projection layers + self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) + if not mean_only: + self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) + # duration predictor + self.duration_predictor = DurationPredictor( + hidden_channels + c_in_channels, filter_channels_dp, kernel_size, + dropout_p) + + def forward(self, x, x_lengths, g=None): + # embedding layer + # [B ,T, D] + x = self.emb(x) * math.sqrt(self.hidden_channels) + # [B, D, T] + x = torch.transpose(x, 1, -1) + # compute input sequence mask + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), + 1).to(x.dtype) + # pre-conv 
layers + if self.encoder_type in ['transformer', 'time-depth-separable']: + if self.use_prenet: + x = self.pre(x, x_mask) + # encoder + x = self.encoder(x, x_mask) + # set duration predictor input + if g is not None: + g_exp = g.expand(-1, -1, x.size(-1)) + x_dp = torch.cat([torch.detach(x), g_exp], 1) + else: + x_dp = torch.detach(x) + # final projection layer + x_m = self.proj_m(x) * x_mask + if not self.mean_only: + x_logs = self.proj_s(x) * x_mask + else: + x_logs = torch.zeros_like(x_m) + # duration predictor + logw = self.duration_predictor(x_dp, x_mask) + return x_m, x_logs, logw, x_mask diff --git a/TTS/tts/layers/glow_tts/gated_conv.py b/TTS/tts/layers/glow_tts/gated_conv.py new file mode 100644 index 00000000..2417ea63 --- /dev/null +++ b/TTS/tts/layers/glow_tts/gated_conv.py @@ -0,0 +1,44 @@ +import torch +from torch import nn + +from .normalization import LayerNorm + + +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. + """ + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [ + nn.Conv1d(in_out_channels, + 2 * in_out_channels, + kernel_size, + padding=kernel_size // 2) + ] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, + p=self.dropout_p, + training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o \ No newline at end of file diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py new file mode 100644 index 00000000..acfc55e5 --- /dev/null +++ b/TTS/tts/layers/glow_tts/glow.py @@ -0,0 +1,334 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .normalization import LayerNorm + + +class ConvLayerNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, + num_layers, dropout_p): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_layers = num_layers + self.dropout_p = dropout_p + assert num_layers > 1, " [!] number of layers should be > 0." + assert kernel_size % 2 == 1, " [!] kernel size should be odd number." 
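        # Editor's note (not part of the original patch): the GatedConvBlock
        # above doubles the channel count in each conv so that F.glu can split
        # the output in half along dim=1 and gate one half with the other,
        # i.e. glu(o) = o[:, :C] * sigmoid(o[:, C:]), which keeps the residual
        # width at `in_out_channels`.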
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + + self.conv_layers.append( + nn.Conv1d(in_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + + for _ in range(num_layers - 1): + self.conv_layers.append( + nn.Conv1d(hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_res = x + for i in range(self.num_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x * x_mask) + x = F.dropout(F.relu(x), self.dropout_p, training=self.training) + x = x_res + self.proj(x) + return x * x_mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WN(torch.nn.Module): + def __init__(self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + assert hidden_channels % 2 == 0 + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.dropout = nn.Dropout(dropout_p) + + if c_in_channels != 0: + cond_layer = torch.nn.Conv1d(c_in_channels, + 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, + name='weight') + + for i in range(num_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + if i < num_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, + res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, + name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.num_layers): + x_in = self.in_layers[i](x) + x_in = self.dropout(x_in) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, + cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, + n_channels_tensor) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.num_layers - 1: + x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.c_in_channels != 0: + 
torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. + + Shapes: + - inputs: (B, C, T) + - outputs: (B, C, T) + """ + + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, + dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to( + dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to( + dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) + + +class InvConvNear(nn.Module): + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + assert num_splits % 2 == 0 + self.channels = channels + self.num_splits = num_splits + self.no_jacobian = no_jacobian + self.weight_inv = None + + w_init = torch.qr( + torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + """Split the input into groups of size self.num_splits and + perform 1x1 convolution separately. Cast 1x1 conv operation + to 2d by reshaping the input for efficienty. 
+ """ + + b, c, t = x.size() + assert c % self.num_splits == 0 + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, + c // self.num_splits, t) + + if reverse: + if self.weight_inv is not None: + weight = self.weight_inv + else: + weight = torch.inverse( + self.weight.float()).to(dtype=self.weight.dtype) + logdet = None + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + else: + logdet = torch.logdet( + self.weight) * (c / self.num_splits) * x_len # [b] + + weight = weight.view(self.num_splits, self.num_splits, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask + return z, logdet + + def store_inverse(self): + self.weight_inv = torch.inverse( + self.weight.float()).to(dtype=self.weight.dtype) + + +class CouplingBlock(nn.Module): + def __init__(self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + self.sigmoid_scale = sigmoid_scale + + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + start = torch.nn.utils.weight_norm(start) + self.start = start + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + + self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, + num_layers, c_in_channels, dropout_p) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + + x = self.start(x_0) * x_mask + x = self.wn(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + m = out[:, :self.in_channels // 2, :] + logs = out[:, self.in_channels // 2:, :] + if self.sigmoid_scale: + logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) + + if reverse: + z_1 = (x_1 - m) * torch.exp(-logs) * x_mask + logdet = None + else: + z_1 = (m + torch.exp(logs) * x_1) * x_mask + logdet = torch.sum(logs * x_mask, [1, 2]) + + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + self.wn.remove_weight_norm() diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py new file mode 100644 index 00000000..267fb7f4 --- /dev/null +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -0,0 +1,49 @@ +import numpy as np +import torch +from torch.nn import functional as F +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + """ + duration: [b, t_x] + mask: [b, t_x, t_y] + """ + device = duration.device + + b, t_x, t_y = mask.shape + 
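    # Editor's note (descriptive comment, not part of the original patch): the
    # lines below turn integer durations into a hard monotonic alignment --
    # cumulative-sum the durations per token, expand each cumulative length
    # into a binary mask over the t_y output frames, then difference
    # consecutive rows so every output frame is assigned to exactly one token.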
cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0] + ]))[:, :-1] + path = path * mask + return path + + +def maximum_path(value, mask): + """ Cython optimised version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/TTS/tts/layers/glow_tts/monotonic_align/core.pyx b/TTS/tts/layers/glow_tts/monotonic_align/core.pyx new file mode 100644 index 00000000..6aabccc4 --- /dev/null +++ b/TTS/tts/layers/glow_tts/monotonic_align/core.pyx @@ -0,0 +1,45 @@ +import numpy as np +cimport numpy as np +cimport cython +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/TTS/tts/layers/glow_tts/monotonic_align/setup.py b/TTS/tts/layers/glow_tts/monotonic_align/setup.py new file mode 100644 index 00000000..30c22480 --- /dev/null +++ b/TTS/tts/layers/glow_tts/monotonic_align/setup.py @@ -0,0 +1,9 @@ +from distutils.core import setup +from Cython.Build import cythonize +import numpy + +setup( + name = 'monotonic_align', + ext_modules = cythonize("core.pyx"), + include_dirs=[numpy.get_include()] +) diff --git a/TTS/tts/layers/glow_tts/normalization.py b/TTS/tts/layers/glow_tts/normalization.py new file mode 100644 index 00000000..70444abc --- /dev/null +++ b/TTS/tts/layers/glow_tts/normalization.py @@ -0,0 +1,101 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + """Layer norm for the 2nd dimension of the input. + Args: + channels (int): number of channels (2nd dimension) of the input. 
+ eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1) + self.beta = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x): + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean)**2, 1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + x = x * self.gamma + self.beta + return x + + +class TemporalBatchNorm1d(nn.BatchNorm1d): + """Normalize each channel separately over time and batch. + """ + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + + def forward(self, x): + return super().forward(x.transpose(2,1)).transpose(2,1) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. + + Shapes: + - inputs: (B, C, T) + - outputs: (B, C, T) + """ + + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, + dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to( + dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to( + dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) \ No newline at end of file diff --git a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py new file mode 100644 index 00000000..19fc7035 --- /dev/null +++ b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py @@ -0,0 +1,94 @@ +import torch +from torch import nn + +from .normalization import LayerNorm + + +class TimeDepthSeparableConv(nn.Module): + """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf + It shows competative results with less computation and memory footprint.""" + def __init__(self, + in_channels, + hid_channels, + out_channels, + kernel_size, + bias=True): + super(TimeDepthSeparableConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hid_channels = hid_channels + self.kernel_size = kernel_size + + self.time_conv = 
nn.Conv1d( + in_channels, + 2 * hid_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm1 = nn.BatchNorm1d(2 * hid_channels) + self.depth_conv = nn.Conv1d( + hid_channels, + hid_channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=hid_channels, + bias=bias, + ) + self.norm2 = nn.BatchNorm1d(hid_channels) + self.time_conv2 = nn.Conv1d( + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm3 = nn.BatchNorm1d(out_channels) + + def forward(self, x): + x_res = x + x = self.time_conv(x) + x = self.norm1(x) + x = nn.functional.glu(x, dim=1) + x = self.depth_conv(x) + x = self.norm2(x) + x = x * torch.sigmoid(x) + x = self.time_conv2(x) + x = self.norm3(x) + x = x_res + x + return x + + +class TimeDepthSeparableConvBlock(nn.Module): + def __init__(self, + in_channels, + hid_channels, + out_channels, + num_layers, + kernel_size, + bias=True): + super(TimeDepthSeparableConvBlock, self).__init__() + assert (kernel_size - 1) % 2 == 0 + assert num_layers > 1 + + self.layers = nn.ModuleList() + layer = TimeDepthSeparableConv( + in_channels, hid_channels, + out_channels if num_layers == 1 else hid_channels, kernel_size, + bias) + self.layers.append(layer) + for idx in range(num_layers - 1): + layer = TimeDepthSeparableConv( + hid_channels, hid_channels, out_channels if + (idx + 1) == (num_layers - 1) else hid_channels, kernel_size, + bias) + self.layers.append(layer) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x * mask) + return x diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py new file mode 100644 index 00000000..5cccea19 --- /dev/null +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -0,0 +1,320 @@ +import copy +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.glow_tts.glow import LayerNorm + + +class RelativePositionMultiHeadAttention(nn.Module): + """Implementation of Relative Position Encoding based on + https://arxiv.org/pdf/1809.04281.pdf + """ + def __init__(self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0., + input_length=None, + proximal_bias=False, + proximal_init=False): + super().__init__() + assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." 
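+        # With `rel_attn_window_size` set, learned relative key/value embeddings
+        # (as in Shaw et al. / the Music Transformer paper cited above) are added
+        # to the content-based attention scores for offsets within the window.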
+ # class attributes + self.channels = channels + self.out_channels = out_channels + self.num_heads = num_heads + self.rel_attn_window_size = rel_attn_window_size + self.heads_share = heads_share + self.input_length = input_length + self.proximal_bias = proximal_bias + self.dropout_p = dropout_p + self.attn = None + # query, key, value layers + self.k_channels = channels // num_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + # output layers + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.dropout = nn.Dropout(dropout_p) + # relative positional encoding layers + if rel_attn_window_size is not None: + n_heads_rel = 1 if heads_share else num_heads + rel_stddev = self.k_channels**-0.5 + emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, + self.k_channels) * rel_stddev) + emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, + self.k_channels) * rel_stddev) + self.register_parameter('emb_rel_k', emb_rel_k) + self.register_parameter('emb_rel_v', emb_rel_v) + + # init layers + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + # proximal bias + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.num_heads, self.k_channels, + t_t).transpose(2, 3) + key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, + t_s).transpose(2, 3) + # compute raw attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( + self.k_channels) + # relative positional encoding + if self.rel_attn_window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + # get relative key embeddings + key_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position( + rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + # proximan bias + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attn_proximity_bias(t_s).to( + device=scores.device, dtype=scores.dtype) + # attention score masking + if mask is not None: + # add small value to prevent oor error. 
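+            # a large but finite negative value (instead of -inf) keeps the softmax
+            # well defined (and fp16-safe) even when a row is fully masked.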
+ scores = scores.masked_fill(mask == 0, -1e4) + if self.input_length is not None: + block_mask = torch.ones_like(scores).triu( + -self.input_length).tril(self.input_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + # attention score normalization + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + # apply dropout to attention weights + p_attn = self.dropout(p_attn) + # compute output + output = torch.matmul(p_attn, value) + # relative positional encoding for values + if self.rel_attn_window_size is not None: + relative_weights = self._absolute_position_to_relative_position( + p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view( + b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, p_attn, re): + """ + Args: + p_attn (Tensor): attention weights. + re (Tensor): relative value embedding vector. (a_(i,j)^V) + + Shapes: + p_attn: [B, H, T, V] + re: [H or 1, V, D] + logits: [B, H, T, D] + """ + logits = torch.matmul(p_attn, re.unsqueeze(0)) + return logits + + @staticmethod + def _matmul_with_relative_keys(query, re): + """ + Args: + query (Tensor): batch of query vectors. (x*W^Q) + re (Tensor): relative key embedding vector. (a_(i,j)^K) + + Shapes: + query: [B, H, T, D] + re: [H or 1, V, D] + logits: [B, H, T, V] + """ + # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) + logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) + return logits + + def _get_relative_embeddings(self, relative_embeddings, length): + """Convert embedding vestors to a tensor of embeddings + """ + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.rel_attn_window_size + 1), 0) + slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, + slice_start_position: + slice_end_position] + return used_relative_embeddings + + @staticmethod + def _relative_position_to_absolute_position(x): + """Converts tensor from relative to absolute indexing for local attention. + Args: + x: [B, D, length, 2 * length - 1] + Returns: + A Tensor of shape [B, D, length, length] + """ + batch, heads, length, _ = x.size() + # Pad to shift from relative to absolute indexing. + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + # Pad extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) + # Reshape and slice out the padded elements. 
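+        # the single-column pad plus flatten shifts row i by i positions, so the
+        # (length + 1, 2*length - 1) view lines relative offsets up with absolute
+        # positions; the final slice drops the padding rows/columns.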
+ x_final = x_flat.view([batch, heads, length + 1, + 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + @staticmethod + def _absolute_position_to_relative_position(x): + """ + x: [B, H, T, T] + ret: [B, H, T, 2*T-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + @staticmethod + def _attn_proximity_bias(length): + """Produce an attention mask that discourages distant + attention values. + Args: + length (int): an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + # L + r = torch.arange(length, dtype=torch.float32) + # L x L + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + # scale mask values + diff = -torch.log1p(torch.abs(diff)) + # 1 x 1 x L x L + return diff.unsqueeze(0).unsqueeze(0) + + +class FFN(nn.Module): + def __init__(self, + in_channels, + out_channels, + filter_channels, + kernel_size, + dropout_p=0., + activation=None): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.activation = activation + + self.conv_1 = nn.Conv1d(in_channels, + filter_channels, + kernel_size, + padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(filter_channels, + out_channels, + kernel_size, + padding=kernel_size // 2) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.dropout(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Transformer(nn.Module): + def __init__(self, + hidden_channels, + filter_channels, + num_heads, + num_layers, + kernel_size=1, + dropout_p=0., + rel_attn_window_size=None, + input_length=None): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for _ in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + dropout_p=dropout_p)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 
008a9dd6..4bc31d90 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -1,3 +1,4 @@ +import math import numpy as np import torch from torch import nn @@ -150,7 +151,7 @@ class GuidedAttentionLoss(torch.nn.Module): @staticmethod def _make_ga_mask(ilen, olen, sigma): - grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=ilen.device)) + grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2))) @@ -243,3 +244,27 @@ class TacotronLoss(torch.nn.Module): return_dict['loss'] = loss return return_dict + + +class GlowTTSLoss(torch.nn.Module): + def __init__(self): + super(GlowTTSLoss, self).__init__() + self.constant_factor = 0.5 * math.log(2 * math.pi) + + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, + o_attn_dur, x_lengths): + return_dict = {} + # flow loss - neg log likelihood + pz = torch.sum(scales) + 0.5 * torch.sum( + torch.exp(-2 * scales) * (z - means)**2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / ( + torch.sum(y_lengths // 2) * 2 * z.shape[1]) + # duration loss - MSE + # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) + # duration loss - huber loss + loss_dur = torch.nn.functional.smooth_l1_loss( + o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths) + return_dict['loss'] = log_mle + loss_dur + return_dict['log_mle'] = log_mle + return_dict['loss_dur'] = loss_dur + return return_dict \ No newline at end of file diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py new file mode 100644 index 00000000..50f08c93 --- /dev/null +++ b/TTS/tts/models/glow_tts.py @@ -0,0 +1,185 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path + + +class GlowTts(nn.Module): + """Glow TTS models from https://arxiv.org/abs/2005.11129""" + def __init__(self, + num_chars, + hidden_channels, + filter_channels, + filter_channels_dp, + out_channels, + kernel_size=3, + num_heads=2, + num_layers_enc=6, + dropout_p=0.1, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=5, + num_block_layers=4, + dropout_p_dec=0., + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_sqz=1, + sigmoid_scale=False, + rel_attn_window_size=None, + input_length=None, + mean_only=False, + hidden_channels_enc=None, + hidden_channels_dec=None, + use_encoder_prenet=False, + encoder_type="transformer"): + + super().__init__() + self.num_chars = num_chars + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.filter_channels_dp = filter_channels_dp + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_heads = num_heads + self.num_layers_enc = num_layers_enc + self.dropout_p = dropout_p + self.num_flow_blocks_dec = num_flow_blocks_dec + self.kernel_size_dec = kernel_size_dec + self.dilation_rate = dilation_rate + self.num_block_layers = num_block_layers + self.dropout_p_dec = dropout_p_dec + self.num_speakers = num_speakers + self.c_in_channels = c_in_channels + self.num_splits = num_splits + self.num_sqz = num_sqz + self.sigmoid_scale = sigmoid_scale + self.rel_attn_window_size = rel_attn_window_size 
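+        # `noise_scale` and `length_scale` (set a few lines below) are inference-time
+        # knobs: the former scales the sampled latent noise (sampling temperature),
+        # the latter stretches or compresses the predicted durations (speaking pace).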
+ self.input_length = input_length + self.mean_only = mean_only + self.hidden_channels_enc = hidden_channels_enc + self.hidden_channels_dec = hidden_channels_dec + self.use_encoder_prenet = use_encoder_prenet + self.noise_scale=0.66 + self.length_scale=1. + + self.encoder = Encoder(num_chars, + out_channels=out_channels, + hidden_channels=hidden_channels, + filter_channels=filter_channels, + filter_channels_dp=filter_channels_dp, + encoder_type=encoder_type, + num_heads=num_heads, + num_layers=num_layers_enc, + kernel_size=kernel_size, + dropout_p=dropout_p, + mean_only=mean_only, + use_prenet=use_encoder_prenet, + c_in_channels=c_in_channels) + + self.decoder = Decoder(out_channels, + hidden_channels_dec or hidden_channels, + kernel_size_dec, + dilation_rate, + num_flow_blocks_dec, + num_block_layers, + dropout_p=dropout_p_dec, + num_splits=num_splits, + num_sqz=num_sqz, + sigmoid_scale=sigmoid_scale, + c_in_channels=c_in_channels) + + if num_speakers > 1: + self.emb_g = nn.Embedding(num_speakers, c_in_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + def compute_outputs(self, attn, o_mean, o_log_scale, x_mask): + # compute final values with the computed alignment + y_mean = torch.matmul( + attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul( + attn.squeeze(1).transpose(1, 2), o_log_scale.transpose( + 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + # compute total duration with adjustment + o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask + return y_mean, y_log_scale, o_attn_dur + + def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None): + """ + Shapes: + x: B x T + x_lenghts: B + y: B x C x T + y_lengths: B + """ + y_max_length = y.size(2) + # norm speaker embeddings + if g is not None: + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # format feature vectors and feature vector lenghts + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path + with torch.no_grad(): + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1,2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1,2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, + attn_mask.squeeze(1)).unsqueeze(1).detach() + y_mean, y_log_scale, o_attn_dur = self.compute_outputs( + attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur + + @torch.no_grad() + def inference(self, x, x_lengths, g=None): + if g is not None: + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # compute output durations + w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale + w_ceil = 
torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = None + # compute masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # compute attention mask + attn = generate_path(w_ceil.squeeze(1), + attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs( + attn, o_mean, o_log_scale, x_mask) + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * + self.noise_scale) * y_mask + # decoder pass + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + attn = attn.squeeze(1).permute(0, 2, 1) + return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur + + def preprocess(self, y, y_lengths, y_max_length, attn=None): + if y_max_length is not None: + y_max_length = (y_max_length // self.num_sqz) * self.num_sqz + y = y[:, :, :y_max_length] + if attn is not None: + attn = attn[:, :, :, :y_max_length] + y_lengths = (y_lengths // self.num_sqz) * self.num_sqz + return y, y_lengths, y_max_length, attn + + def store_inverse(self): + self.decoder.store_inverse() diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 6358e5a9..0ff462dd 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -1,3 +1,4 @@ +import re import torch import importlib import numpy as np @@ -31,21 +32,22 @@ def split_dataset(items): def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() - batch_size = sequence_length.size(0) - seq_range = torch.arange(0, max_len).long() - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - if sequence_length.is_cuda: - seq_range_expand = seq_range_expand.to(sequence_length.device) - seq_length_expand = ( - sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + seq_range = torch.arange(max_len, + dtype=sequence_length.dtype, + device=sequence_length.device) # B x T_max - return seq_range_expand < seq_length_expand + return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) + + +def to_camel(text): + text = text.capitalize() + return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('TTS.tts.models.' 
+ c.model.lower()) - MyModel = getattr(MyModel, c.model) + MyModel = getattr(MyModel, to_camel(c.model)) if c.model.lower() in "tacotron": model = MyModel(num_chars=num_chars, num_speakers=num_speakers, @@ -97,6 +99,31 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): double_decoder_consistency=c.double_decoder_consistency, ddc_r=c.ddc_r, speaker_embedding_dim=speaker_embedding_dim) + elif c.model.lower() == "glow_tts": + model = MyModel(num_chars=num_chars, + hidden_channels=192, + filter_channels=768, + filter_channels_dp=256, + out_channels=80, + kernel_size=3, + num_heads=2, + num_layers_enc=6, + encoder_type=c.encoder_type, + dropout_p=0.1, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=1, + num_block_layers=4, + dropout_p_dec=0.05, + num_speakers=num_speakers, + c_in_channels=0, + num_splits=4, + num_sqz=2, + sigmoid_scale=False, + mean_only=True, + hidden_channels_enc=192, + hidden_channels_dec=192, + use_encoder_prenet=True) return model diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index 78e9b8b2..2bc755b4 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -18,7 +18,7 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): if use_cuda: model.cuda() # set model stepsize - if 'r' in state.keys(): + if hasattr(model.decoder, 'r'): model.decoder.set_r(state['r']) return model, state diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 76ac7909..48083a2a 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -46,16 +46,24 @@ def compute_style_mel(style_wav, ap, cuda=False): def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): - if CONFIG.use_gst: - decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) - else: - if truncated: - decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) - else: + if 'tacotron' in CONFIG.model.lower(): + if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + else: + if truncated: + decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + else: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + elif 'glow' in CONFIG.model.lower(): + inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) + postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths) + postnet_output = postnet_output.permute(0, 2, 1) + # these only belong to tacotron models. 
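+        # Glow-TTS decodes the whole spectrogram in a single forward pass, so there
+        # is no autoregressive decoder output and no stopnet prediction to return.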
+ decoder_output = None + stop_tokens = None return decoder_output, postnet_output, alignments, stop_tokens @@ -99,9 +107,9 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens): postnet_output = postnet_output[0].data.cpu().numpy() - decoder_output = decoder_output[0].data.cpu().numpy() + decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy() alignment = alignments[0].cpu().data.numpy() - stop_tokens = stop_tokens[0].cpu().numpy() + stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy() return postnet_output, decoder_output, alignment, stop_tokens diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 500d7707..033a5191 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -6,14 +6,20 @@ import matplotlib.pyplot as plt from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme -def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): +def plot_alignment(alignment, + info=None, + fig_size=(16, 10), + title=None, + output_fig=False): if isinstance(alignment, torch.Tensor): alignment_ = alignment.detach().cpu().numpy().squeeze() else: alignment_ = alignment fig, ax = plt.subplots(figsize=fig_size) - im = ax.imshow( - alignment_.T, aspect='auto', origin='lower', interpolation='none') + im = ax.imshow(alignment_.T, + aspect='auto', + origin='lower', + interpolation='none') fig.colorbar(im, ax=ax) xlabel = 'Decoder timestep' if info is not None: @@ -29,7 +35,10 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_f return fig -def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): +def plot_spectrogram(spectrogram, + ap=None, + fig_size=(16, 10), + output_fig=False): if isinstance(spectrogram, torch.Tensor): spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T else: @@ -45,7 +54,17 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): return fig -def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False): +def visualize(alignment, + postnet_output, + text, + hop_length, + CONFIG, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False): + if decoder_output is not None: num_plot = 4 else: @@ -60,18 +79,30 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, plt.ylabel("Encoder timestamp", fontsize=label_fontsize) # compute phoneme representation and back if CONFIG.use_phonemes: - seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) - text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + seq = phoneme_to_sequence( + text, [CONFIG.text_cleaner], + CONFIG.phoneme_language, + CONFIG.enable_eos_bos_chars, + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + text = sequence_to_phoneme( + seq, + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) print(text) plt.yticks(range(len(text)), list(text)) plt.colorbar() - # plot stopnet predictions - plt.subplot(num_plot, 1, 2) - plt.plot(range(len(stop_tokens)), list(stop_tokens)) + + if stop_tokens is not None: + # plot stopnet predictions + plt.subplot(num_plot, 1, 
2) + plt.plot(range(len(stop_tokens)), list(stop_tokens)) + # plot postnet spectrogram plt.subplot(num_plot, 1, 3) - librosa.display.specshow(postnet_output.T, sr=CONFIG.audio['sample_rate'], - hop_length=hop_length, x_axis="time", y_axis="linear", + librosa.display.specshow(postnet_output.T, + sr=CONFIG.audio['sample_rate'], + hop_length=hop_length, + x_axis="time", + y_axis="linear", fmin=CONFIG.audio['mel_fmin'], fmax=CONFIG.audio['mel_fmax']) @@ -82,8 +113,11 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, if decoder_output is not None: plt.subplot(num_plot, 1, 4) - librosa.display.specshow(decoder_output.T, sr=CONFIG.audio['sample_rate'], - hop_length=hop_length, x_axis="time", y_axis="linear", + librosa.display.specshow(decoder_output.T, + sr=CONFIG.audio['sample_rate'], + hop_length=hop_length, + x_axis="time", + y_axis="linear", fmin=CONFIG.audio['mel_fmin'], fmax=CONFIG.audio['mel_fmax']) plt.xlabel("Time", fontsize=label_fontsize) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 7a13d14b..aaa14dfd 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -174,8 +174,9 @@ class AudioProcessor(object): for key in stats_config.keys(): if key in skip_parameters: continue - assert stats_config[key] == self.__dict__[key],\ - f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + if key != 'sample_rate': + assert stats_config[key] == self.__dict__[key],\ + f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" return mel_mean, mel_std, linear_mean, linear_std, stats_config # pylint: disable=attribute-defined-outside-init @@ -322,6 +323,7 @@ class AudioProcessor(object): def load_wav(self, filename, sr=None): if sr is None: x, sr = sf.read(filename) + assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr) else: x, sr = librosa.load(filename, sr=sr) if self.do_trim_silence: @@ -329,7 +331,6 @@ class AudioProcessor(object): x = self.trim_silence(x) except ValueError: print(f' [!] 
File cannot be trimmed for silence - {filename}') - assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr) if self.do_sound_norm: x = self.sound_norm(x) return x diff --git a/TTS/utils/io.py b/TTS/utils/io.py index c54d2e9f..07ec63a0 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -5,8 +5,8 @@ import pickle as pickle_tts class RenamingUnpickler(pickle_tts.Unpickler): """Overload default pickler to solve module renaming problem""" def find_class(self, module, name): - if 'mozilla_voice_tts' in module : - module = module.replace('mozilla_voice_tts', 'TTS') + if 'TTS' in module : + module = module.replace('TTS', 'TTS') return super().find_class(module, name) class AttrDict(dict): diff --git a/TTS/vocoder/layers/pqmf.py b/TTS/vocoder/layers/pqmf.py index ef5a3507..d31953d6 100644 --- a/TTS/vocoder/layers/pqmf.py +++ b/TTS/vocoder/layers/pqmf.py @@ -22,7 +22,7 @@ class PQMF(torch.nn.Module): for k in range(N): constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) + ((taps - 1) / 2)) # TODO: (taps - 1) -> taps phase = (-1)**k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py new file mode 100644 index 00000000..9f90ee17 --- /dev/null +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -0,0 +1,31 @@ +import torch + +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class FullbandMelganGenerator(MelganGenerator): + def __init__(self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4): + super(FullbandMelganGenerator, + self).__init__(in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, + (self.inference_padding, self.inference_padding), + 'replicate') + return self.layers(cond_features) diff --git a/TTS/vocoder/tf/layers/melgan.py b/TTS/vocoder/tf/layers/melgan.py index f9806579..34b25d65 100644 --- a/TTS/vocoder/tf/layers/melgan.py +++ b/TTS/vocoder/tf/layers/melgan.py @@ -48,7 +48,6 @@ class ResidualStack(tf.keras.layers.Layer): ] def call(self, x): - # breakpoint() for block, shortcut in zip(self.blocks, self.shortcuts): res = shortcut(x) for layer in block: diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 4b9a7c3f..89dc68fb 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -67,6 +67,15 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) + if c.generator_model in 'fullband_melgan_generator': + model = MyModel( + in_channels=c.audio['num_mels'], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params['upsample_factors'], + res_kernel=3, + num_res_blocks=c.generator_model_params['num_res_blocks']) if c.generator_model in 'parallel_wavegan_generator': model = MyModel( in_channels=1, diff --git a/setup.py b/setup.py index 4aa3d52a..3126cf6d 100644 --- a/setup.py +++ b/setup.py @@ -5,11 +5,21 @@ import os import shutil import subprocess import 
sys +import numpy from setuptools import setup, find_packages import setuptools.command.develop import setuptools.command.build_py +# handle import if cython is not already installed. +try: + from Cython.Build import cythonize +except ImportError: + # create closure for deferred import + def cythonize (*args, ** kwargs ): + from Cython.Build import cythonize + return cythonize(*args, ** kwargs) + parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file to embed in wheel.') @@ -36,6 +46,16 @@ else: pass +# Handle Cython code +def find_pyx(path='.'): + pyx_files = [] + for root, dirs, filenames in os.walk(path): + for fname in filenames: + if fname.endswith('.pyx'): + pyx_files.append(os.path.join(root, fname)) + return pyx_files + + class build_py(setuptools.command.build_py.build_py): # pylint: disable=too-many-ancestors def run(self): self.create_version_file() @@ -99,6 +119,8 @@ setup( 'tts-server = TTS.server.server:main' ] }, + include_dirs=[numpy.get_include()], + ext_modules=cythonize(find_pyx(), language_level=3), packages=find_packages(include=['TTS*']), project_urls={ 'Documentation': 'https://github.com/mozilla/TTS/wiki', diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py new file mode 100644 index 00000000..6f3cdb81 --- /dev/null +++ b/tests/test_glow_tts.py @@ -0,0 +1,135 @@ +import copy +import os +import unittest + +import torch +from tests import get_tests_input_path +from torch import nn, optim + +from TTS.tts.layers.losses import GlowTTSLoss +from TTS.tts.models.glow_tts import GlowTts +from TTS.utils.io import load_config +from TTS.utils.audio import AudioProcessor + +#pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) + +ap = AudioProcessor(**c.audio) +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class GlowTTSTrainTest(unittest.TestCase): + @staticmethod + def test_train_step(): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths[-1] = 128 + mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device) + linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + criterion = criterion = GlowTTSLoss() + + # model to train + model = GlowTts( + num_chars=32, + hidden_channels=128, + filter_channels=32, + filter_channels_dp=32, + out_channels=80, + kernel_size=3, + num_heads=2, + num_layers_enc=6, + dropout_p=0.1, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=5, + num_block_layers=4, + dropout_p_dec=0., + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_sqz=1, + sigmoid_scale=False, + rel_attn_window_size=None, + input_length=None, + mean_only=False, + hidden_channels_enc=None, + hidden_channels_dec=None, + use_encoder_prenet=False, + encoder_type="transformer" + ).to(device) + + # reference model to compare model weights + model_ref = GlowTts( + num_chars=32, + hidden_channels=128, + filter_channels=32, + filter_channels_dp=32, + out_channels=80, + 
kernel_size=3, + num_heads=2, + num_layers_enc=6, + dropout_p=0.1, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=5, + num_block_layers=4, + dropout_p_dec=0., + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_sqz=1, + sigmoid_scale=False, + rel_attn_window_size=None, + input_length=None, + mean_only=False, + hidden_channels_enc=None, + hidden_channels_dec=None, + use_encoder_prenet=False, + encoder_type="transformer" + ).to(device) + + model.train() + print(" > Num parameters for GlowTTS model:%s" % + (count_parameters(model))) + + # pass the state to ref model + model_ref.load_state_dict(copy.deepcopy(model.state_dict())) + + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for _ in range(5): + z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, None) + optimizer.zero_grad() + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, + o_dur_log, o_total_dur, input_lengths) + loss = loss_dict['loss'] + loss.backward() + optimizer.step() + + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 \ No newline at end of file
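
For quick manual verification of the new model, a minimal inference smoke test could look like the sketch below. This is an illustrative sketch, not part of the patch: it assumes the package is installed with the monotonic_align Cython extension built (e.g. `python setup.py build_ext --inplace` inside that folder, or via the setup.py changes above), reuses the hyperparameters from GlowTTSTrainTest, and feeds random character ids rather than real text.

import torch
from TTS.tts.models.glow_tts import GlowTts

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# same hyperparameters as the unit test above; values are placeholders, not tuned settings
model = GlowTts(num_chars=32,
                hidden_channels=128,
                filter_channels=32,
                filter_channels_dp=32,
                out_channels=80,
                kernel_size=3,
                num_heads=2,
                num_layers_enc=6,
                dropout_p=0.1,
                num_flow_blocks_dec=12,
                kernel_size_dec=5,
                dilation_rate=5,
                num_block_layers=4,
                dropout_p_dec=0.,
                num_speakers=0,
                c_in_channels=0,
                num_splits=4,
                num_sqz=1,
                sigmoid_scale=False,
                mean_only=False,
                encoder_type="transformer").to(device)
model.eval()

# dummy character ids and their lengths
x = torch.randint(0, 32, (1, 50)).long().to(device)
x_lengths = torch.tensor([50]).to(device)

# non-autoregressive synthesis: a single pass predicts durations and decodes the mel
mel, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur = model.inference(x, x_lengths)
print(mel.shape)  # (1, 80, T_out), where T_out follows the predicted durations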