diff --git a/TTS/tts/configs/glow_tts_tdsep.json b/TTS/tts/configs/glow_tts_tdsep.json index 25d41291..72eb3da7 100644 --- a/TTS/tts/configs/glow_tts_tdsep.json +++ b/TTS/tts/configs/glow_tts_tdsep.json @@ -1,7 +1,7 @@ { "model": "glow_tts", - "run_name": "glow-tts-tdsep-conv", - "run_description": "glow-tts model training with time-depth separable conv encoder.", + "run_name": "glow-tts-residual_bn_conv", + "run_description": "glow-tts model training with residual BN conv.", // AUDIO PARAMETERS "audio":{ @@ -28,7 +28,7 @@ "num_mels": 80, // size of the mel spec frame. "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! - "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + "spec_gain": 1.0, // scaler value applied after log transform of spectrogram. // Normalization parameters "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. @@ -62,13 +62,15 @@ "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. // MODEL PARAMETERS - "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. + // "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. // TRAINING "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. "eval_batch_size":16, "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. "loss_masking": true, // enable / disable loss masking against the sequence padding. 
+ "mixed_precision": true, + "data_dep_init_iter": 10, // VALIDATION "run_eval": true, @@ -84,7 +86,7 @@ "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. - "encoder_type": "time-depth-separable", + "encoder_type": "residual_conv_bn", // TENSORBOARD and LOGGING "print_step": 25, // Number of steps to log training on console. @@ -93,7 +95,6 @@ "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. "checkpoint": true, // If true, it saves checkpoints per "save_step" "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - "apex_amp_level": null, // DATA LOADING "text_cleaner": "phoneme_cleaners", @@ -104,6 +105,7 @@ "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training "max_seq_len": 500, // DATASET-RELATED: maximum text length "compute_f0": false, // compute f0 values in data-loader + "compute_input_seq_cache": true, // PATHS "output_path": "/home/erogol/Models/LJSpeech/", @@ -115,6 +117,7 @@ // MULTI-SPEAKER and GST "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "use_external_speaker_embedding_file": false, "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. 
"use_gst": false, // TACOTRON ONLY: use global style tokens diff --git a/TTS/tts/layers/generic/__init__.py b/TTS/tts/layers/generic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/glow_tts/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py similarity index 100% rename from TTS/tts/layers/glow_tts/gated_conv.py rename to TTS/tts/layers/generic/gated_conv.py diff --git a/TTS/tts/layers/glow_tts/normalization.py b/TTS/tts/layers/generic/normalization.py similarity index 100% rename from TTS/tts/layers/glow_tts/normalization.py rename to TTS/tts/layers/generic/normalization.py diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py new file mode 100644 index 00000000..f67a8613 --- /dev/null +++ b/TTS/tts/layers/generic/res_conv_bn.py @@ -0,0 +1,66 @@ +import torch +from torch import nn +from .normalization import TemporalBatchNorm1d + + +class ZeroTemporalPad(nn.ZeroPad2d): + """Pad sequences to equal length in the temporal dimension""" + def __init__(self, kernel_size, dilation): + total_pad = (dilation * (kernel_size - 1)) + begin = total_pad // 2 + end = total_pad - begin + super(ZeroTemporalPad, self).__init__((0, 0, begin, end)) + + +class ConvBN(nn.Module): + def __init__(self, channels, kernel_size, dilation): + super().__init__() + padding = (dilation * (kernel_size - 1)) + pad_s = padding // 2 + pad_e = padding - pad_s + self.conv1d = nn.Conv1d(channels, channels, kernel_size, dilation=dilation) + self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0)) # uneven left and right padding + self.norm = nn.BatchNorm1d(channels) + + def forward(self, x): + o = self.conv1d(x) + o = self.pad(o) + o = self.norm(o) + o = nn.functional.relu(o) + return o + + +class ConvBNBlock(nn.Module): + """Implements conv->pad->BatchNorm->ReLU n-times""" + + def __init__(self, channels, kernel_size, dilation, num_conv_blocks=2): + super().__init__() + self.conv_bn_blocks = nn.Sequential(*[ + ConvBN(channels, kernel_size, 
dilation) + for _ in range(num_conv_blocks) + ]) + + def forward(self, x): + """ + Shapes: + x: (B, D, T) + """ + return self.conv_bn_blocks(x) + + +class ResidualConvBNBlock(nn.Module): + def __init__(self, channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2): + super().__init__() + assert len(dilations) == num_res_blocks + self.res_blocks = nn.ModuleList() + for dilation in dilations: + block = ConvBNBlock(channels, kernel_size, dilation, num_conv_blocks) + self.res_blocks.append(block) + + def forward(self, x, x_mask=None): + o = x + for block in self.res_blocks: + res = o + o = block(o * x_mask if x_mask is not None else o) + o = o + res + return o diff --git a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py b/TTS/tts/layers/generic/time_depth_sep_conv.py similarity index 100% rename from TTS/tts/layers/glow_tts/time_depth_sep_conv.py rename to TTS/tts/layers/generic/time_depth_sep_conv.py diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py index 67329a2a..6788132e 100644 --- a/TTS/tts/layers/glow_tts/decoder.py +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -2,7 +2,7 @@ import torch from torch import nn from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock -from TTS.tts.layers.glow_tts.normalization import ActNorm +from TTS.tts.layers.generic.normalization import ActNorm def squeeze(x, x_mask=None, num_sqz=2): diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py index 9f825832..b6383674 100644 --- a/TTS/tts/layers/glow_tts/duration_predictor.py +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -1,7 +1,7 @@ import torch from torch import nn -from .normalization import LayerNorm +from ..generic.normalization import LayerNorm class DurationPredictor(nn.Module): diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index c5af85ec..e9b19aa4 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ 
b/TTS/tts/layers/glow_tts/encoder.py @@ -3,11 +3,12 @@ import torch from torch import nn from TTS.tts.layers.glow_tts.transformer import Transformer -from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock +from TTS.tts.layers.generic.gated_conv import GatedConvBlock from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.glow import ConvLayerNorm from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor -from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock class Encoder(nn.Module): @@ -84,12 +85,26 @@ class Encoder(nn.Module): dropout_p=dropout_p, rel_attn_window_size=rel_attn_window_size, input_length=input_length) - elif encoder_type.lower() == 'gatedconv': + elif encoder_type.lower() == 'gated_conv': self.encoder = GatedConvBlock(hidden_channels, kernel_size=5, dropout_p=dropout_p, num_layers=3 + num_layers) - elif encoder_type.lower() == 'time-depth-separable': + elif encoder_type.lower() == 'residual_conv_bn': + if use_prenet: + self.pre = nn.Sequential( + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU() + ) + dilations = 4 * [1, 2, 4] + [1] + num_conv_blocks = 2 + num_res_blocks = 13 # total 2 * 13 blocks + self.encoder = ResidualConvBNBlock(hidden_channels, + kernel_size=4, + dilations=dilations, + num_res_blocks=num_res_blocks, + num_conv_blocks=num_conv_blocks) + elif encoder_type.lower() == 'time_depth_separable': # optional convolutional prenet if use_prenet: self.pre = ConvLayerNorm(hidden_channels, diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index b06dd8a5..7b394e43 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -2,7 +2,7 @@ import torch from torch import nn from torch.nn import functional as F -from .normalization import LayerNorm +from 
..generic.normalization import LayerNorm class ConvLayerNorm(nn.Module): diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 734c11e5..91ee3fa3 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -141,7 +141,7 @@ class GlowTts(nn.Module): o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) - # format feature vectors and feature vector lenghts + # drop residual frames wrt num_sqz and set y_lengths. y, y_lengths, y_max_length, attn = self.preprocess( y, y_lengths, y_max_length, None) # create masks