diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index 3d71dbd5..cf9d98d2 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -10,37 +10,32 @@ import traceback
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
 from torch.nn.parallel import DistributedDataParallel as DDP
-
+from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
-from TTS.utils.console_logger import ConsoleLogger
-from TTS.tts.utils.distribute import (DistributedSampler,
-                                      init_distributed,
-                                      reduce_tensor)
+from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
+                                      reduce_tensor)
 from TTS.tts.utils.generic_utils import check_config, setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import (get_speakers,
-                                    load_speaker_mapping,
-                                    save_speaker_mapping)
+from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
+                                    save_speaker_mapping)
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import (
-    KeepAverage, count_parameters, create_experiment_folder, get_git_branch,
-    remove_experiment_folder, set_init_dict)
+from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.generic_utils import (KeepAverage, count_parameters,
+                                     create_experiment_folder, get_git_branch,
+                                     remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, adam_weight_decay,
-                                check_update,
-                                gradual_training_scheduler,
-                                set_weight_decay,
-                                setup_torch_training_env)
+from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
+                                gradual_training_scheduler, set_weight_decay,
+                                setup_torch_training_env)
 
 use_cuda, num_gpus = setup_torch_training_env(True, False)
@@ -116,7 +111,7 @@ def format_data(data):
 
 
 def data_depended_init(model, ap):
-    """Data depended initialization for normalization layers."""
+    """Data-dependent initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
             if getattr(f, "set_ddi", False):
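
The data-dependent initialization (DDI) that the corrected docstring describes follows a simple toggle pattern. Below is a minimal sketch of that pattern; `run_ddi` is a hypothetical helper, but the `decoder.flows` traversal and `set_ddi` calls mirror `data_depended_init()` above and `ActNorm` later in this diff.

```python
# Minimal sketch of the DDI toggle used by data_depended_init():
# switch ActNorm flows into DDI mode, run one batch through the model so
# ActNorm.initialize() sees real activation statistics, then switch back.
import torch


def run_ddi(model, text, text_lengths, mels, mel_lengths):
    for f in model.decoder.flows:
        if getattr(f, "set_ddi", False):
            f.set_ddi(True)  # next forward pass triggers initialization
    model.train()
    with torch.no_grad():
        model.forward(text, text_lengths, mels, mel_lengths)
    for f in model.decoder.flows:
        if getattr(f, "set_ddi", False):
            f.set_ddi(False)  # keep the initialized bias/logs fixed
```
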
+ "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Griffin-Lim + "power": 1.1, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // MODEL PARAMETERS + "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. + + // TRAINING + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size":16, + "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": true, // use noam warmup and lr schedule. + "grad_clip": 5.0, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate. 
+ "wd": 0.000001, // Weight decay weight. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. + + "encoder_type": "gatedconv", + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "apex_amp_level": null, + + // DATA LOADING + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 500, // DATASET-RELATED: maximum text length + "compute_f0": false, // compute f0 values in data-loader + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/", + + // PHONEMES + "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. + "use_gst": false, // TACOTRON ONLY: use global style tokens + + // DATASETS + "datasets": // List of datasets. They all merged and they get different speaker_ids. + [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null + // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/" + } + ] +} + + diff --git a/TTS/tts/configs/glow_tts_tdsep.json b/TTS/tts/configs/glow_tts_tdsep.json new file mode 100644 index 00000000..67047523 --- /dev/null +++ b/TTS/tts/configs/glow_tts_tdsep.json @@ -0,0 +1,132 @@ +{ + "model": "glow_tts", + "run_name": "glow-tts-tdsep-conv", + "run_description": "glow-tts model training with time-depth separable conv encoder.", + + // AUDIO PARAMETERS + "audio":{ + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. 
+ "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Griffin-Lim + "power": 1.1, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // MODEL PARAMETERS + "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. + + // TRAINING + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size":16, + "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": true, // use noam warmup and lr schedule. + "grad_clip": 5.0, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate. 
+ "wd": 0.000001, // Weight decay weight. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. + + "encoder_type": "time-depth-separable", + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "apex_amp_level": null, + + // DATA LOADING + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 500, // DATASET-RELATED: maximum text length + "compute_f0": false, // compute f0 values in data-loader + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/", + + // PHONEMES + "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. + "use_gst": false, // TACOTRON ONLY: use global style tokens + + // DATASETS + "datasets": // List of datasets. They all merged and they get different speaker_ids. + [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null + // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/" + } + ] + } + + diff --git a/TTS/tts/layers/glow_tts/gated_conv.py b/TTS/tts/layers/glow_tts/gated_conv.py new file mode 100644 index 00000000..2417ea63 --- /dev/null +++ b/TTS/tts/layers/glow_tts/gated_conv.py @@ -0,0 +1,44 @@ +import torch +from torch import nn + +from .normalization import LayerNorm + + +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. 
+ """ + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [ + nn.Conv1d(in_out_channels, + 2 * in_out_channels, + kernel_size, + padding=kernel_size // 2) + ] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, + p=self.dropout_p, + training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o \ No newline at end of file diff --git a/TTS/tts/layers/glow_tts/normalization.py b/TTS/tts/layers/glow_tts/normalization.py new file mode 100644 index 00000000..70444abc --- /dev/null +++ b/TTS/tts/layers/glow_tts/normalization.py @@ -0,0 +1,101 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + """Layer norm for the 2nd dimension of the input. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1) + self.beta = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x): + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean)**2, 1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + x = x * self.gamma + self.beta + return x + + +class TemporalBatchNorm1d(nn.BatchNorm1d): + """Normalize each channel separately over time and batch. + """ + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + + def forward(self, x): + return super().forward(x.transpose(2,1)).transpose(2,1) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. 
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 2ef92422..50f08c93 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -112,10 +112,11 @@ class GlowTts(nn.Module):
 
     def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
         """
-        x: B x T
-        x_lenghts: B
-        y: B x D x T
-        y_lengths: B
+        Shapes:
+            x: B x T
+            x_lengths: B
+            y: B x C x T
+            y_lengths: B
         """
         y_max_length = y.size(2)
         # norm speaker embeddings
diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py
new file mode 100644
index 00000000..6f3cdb81
--- /dev/null
+++ b/tests/test_glow_tts.py
@@ -0,0 +1,135 @@
+import copy
+import os
+import unittest
+
+import torch
+from tests import get_tests_input_path
+from torch import nn, optim
+
+from TTS.tts.layers.losses import GlowTTSLoss
+from TTS.tts.models.glow_tts import GlowTts
+from TTS.utils.io import load_config
+from TTS.utils.audio import AudioProcessor
+
+#pylint: disable=unused-variable
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+
+ap = AudioProcessor(**c.audio)
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+
+
+def count_parameters(model):
+    r"""Count the number of trainable parameters in a network."""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class GlowTTSTrainTest(unittest.TestCase):
+    @staticmethod
+    def test_train_step():
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
+        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
+        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
+
+        criterion = GlowTTSLoss()
+
+        # model to train
+        model = GlowTts(
+            num_chars=32,
+            hidden_channels=128,
+            filter_channels=32,
+            filter_channels_dp=32,
+            out_channels=80,
+            kernel_size=3,
+            num_heads=2,
+            num_layers_enc=6,
+            dropout_p=0.1,
+            num_flow_blocks_dec=12,
+            kernel_size_dec=5,
+            dilation_rate=5,
+            num_block_layers=4,
+            dropout_p_dec=0.,
+            num_speakers=0,
+            c_in_channels=0,
+            num_splits=4,
+            num_sqz=1,
+            sigmoid_scale=False,
+            rel_attn_window_size=None,
+            input_length=None,
+            mean_only=False,
+            hidden_channels_enc=None,
+            hidden_channels_dec=None,
+            use_encoder_prenet=False,
+            encoder_type="transformer"
+        ).to(device)
+
+        # reference model to compare model weights
+        model_ref = GlowTts(
+            num_chars=32,
+            hidden_channels=128,
+            filter_channels=32,
+            filter_channels_dp=32,
+            out_channels=80,
+            kernel_size=3,
+            num_heads=2,
+            num_layers_enc=6,
+            dropout_p=0.1,
+            num_flow_blocks_dec=12,
+            kernel_size_dec=5,
+            dilation_rate=5,
+            num_block_layers=4,
+            dropout_p_dec=0.,
+            num_speakers=0,
+            c_in_channels=0,
+            num_splits=4,
+            num_sqz=1,
+            sigmoid_scale=False,
+            rel_attn_window_size=None,
+            input_length=None,
+            mean_only=False,
+            hidden_channels_enc=None,
+            hidden_channels_dec=None,
+            use_encoder_prenet=False,
encoder_type="transformer" + ).to(device) + + model.train() + print(" > Num parameters for GlowTTS model:%s" % + (count_parameters(model))) + + # pass the state to ref model + model_ref.load_state_dict(copy.deepcopy(model.state_dict())) + + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for _ in range(5): + z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, None) + optimizer.zero_grad() + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, + o_dur_log, o_total_dur, input_lengths) + loss = loss_dict['loss'] + loss.backward() + optimizer.step() + + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 \ No newline at end of file