mirror of https://github.com/coqui-ai/TTS.git
glow-tts modules added
This commit is contained in:
parent e4c6386603
commit e0b9fa887f
@@ -10,37 +10,32 @@ import traceback
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP

from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.utils.console_logger import ConsoleLogger
from TTS.tts.utils.distribute import (DistributedSampler,
                                      init_distributed,
                                      reduce_tensor)
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
                                      reduce_tensor)
from TTS.tts.utils.generic_utils import check_config, setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import (get_speakers,
                                    load_speaker_mapping,
                                    save_speaker_mapping)
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
                                    save_speaker_mapping)
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (
    KeepAverage, count_parameters, create_experiment_folder, get_git_branch,
    remove_experiment_folder, set_init_dict)
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     create_experiment_folder, get_git_branch,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import (NoamLR, adam_weight_decay,
                                check_update,
                                gradual_training_scheduler,
                                set_weight_decay,
                                setup_torch_training_env)
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
                                gradual_training_scheduler, set_weight_decay,
                                setup_torch_training_env)

use_cuda, num_gpus = setup_torch_training_env(True, False)

@@ -116,7 +111,7 @@ def format_data(data):


def data_depended_init(model, ap):
    """Data-dependent initialization for normalization layers."""
    """Data-dependent initialization for activation normalization."""
    if hasattr(model, 'module'):
        for f in model.module.decoder.flows:
            if getattr(f, "set_ddi", False):
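For context, a minimal hedged sketch of how this data-dependent initialization is usually driven once before regular training starts. The helper name run_ddi_once and the loader layout are illustrative, and model.decoder.flows assumes an unwrapped (non-DDP) GlowTts model; the actual training script may differ in details.

import torch

def run_ddi_once(model, train_loader):
    # flag every flow that supports DDI so its next forward pass computes init stats
    for f in model.decoder.flows:
        if getattr(f, "set_ddi", False):
            f.set_ddi(True)
    # one forward pass on real data; ActNorm.initialize() then fills bias/logs
    model.train()
    with torch.no_grad():
        text, text_len, mel, mel_len = next(iter(train_loader))  # assumed batch layout
        model.forward(text, text_len, mel, mel_len, None)
    return model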
@@ -0,0 +1,132 @@
{
    "model": "glow_tts",
    "run_name": "glow-tts-gatedconv",
    "run_description": "glow-tts model training with gated conv.",

    // AUDIO PARAMETERS
    "audio":{
        "fft_size": 1024,         // number of stft frequency levels. Size of the linear spectrogram frame.
        "win_length": 1024,       // stft window length in samples.
        "hop_length": 256,        // stft window hop length in samples.
        "frame_length_ms": null,  // stft window length in ms. If null, 'win_length' is used.
        "frame_shift_ms": null,   // stft window hop length in ms. If null, 'hop_length' is used.

        // Audio processing parameters
        "sample_rate": 22050,   // DATASET-RELATED: wav sample rate. If different than the original data, it is resampled.
        "preemphasis": 0.0,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "ref_level_db": 0,      // reference level db; theoretically 20 db is the sound of air.

        // Griffin-Lim
        "power": 1.1,              // value to sharpen wav signals after the GL algorithm.
        "griffin_lim_iters": 60,   // number of Griffin-Lim iterations. 30-60 is a good range. The larger the value, the slower the generation.

        // Silence trimming
        "do_trim_silence": true,   // enable trimming of silence from audio as it is loaded. LJSpeech (false), TWEB (false), Nancy (true)
        "trim_db": 60,             // threshold for trimming silence. Set this according to your dataset.

        // MelSpectrogram parameters
        "num_mels": 80,        // size of the mel spec frame.
        "mel_fmin": 50.0,      // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for your dataset!!
        "mel_fmax": 7600.0,    // maximum freq level for mel-spec. Tune for your dataset!!
        "spec_gain": 1.0,      // scalar value applied after the log transform of the spectrogram.

        // Normalization parameters
        "signal_norm": true,     // normalize spec values. Mean-var normalization if 'stats_path' is defined, otherwise range normalization defined by the other params.
        "min_level_db": -100,    // lower bound for normalization
        "symmetric_norm": true,  // move normalization to the range [-1, 1]
        "max_norm": 1.0,         // scale normalization to the range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,       // clip normalized values into the range.
        "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy"    // DO NOT USE WITH MULTI_SPEAKER MODEL. Scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and the other normalization params are ignored.
    },

    // VOCABULARY PARAMETERS
    // if a custom character set is not defined,
    // the default set in symbols.py is used
    // "characters":{
    //     "pad": "_",
    //     "eos": "~",
    //     "bos": "^",
    //     "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
    //     "punctuations":"!'(),-.:;? ",
    //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
    // },

    // DISTRIBUTED TRAINING
    "distributed":{
        "backend": "nccl",
        "url": "tcp://localhost:54321"
    },

    "reinit_layers": [],    // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

    // MODEL PARAMETERS
    "use_mas": false,       // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.

    // TRAINING
    "batch_size": 32,       // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.
    "eval_batch_size": 16,
    "r": 1,                 // Number of decoder frames to predict per iteration. Set the initial value if gradual training is enabled.
    "loss_masking": true,   // enable / disable loss masking against the sequence padding.

    // VALIDATION
    "run_eval": true,
    "test_delay_epochs": 0,        // Until attention is aligned, testing only wastes computation time.
    "test_sentences_file": null,   // set a file to load sentences to be used for testing. If it is null, the default English sentences are used.

    // OPTIMIZER
    "noam_schedule": true,   // use Noam warmup and lr schedule.
    "grad_clip": 5.0,        // upper limit for gradient clipping.
    "epochs": 10000,         // total number of epochs to train.
    "lr": 1e-3,              // Initial learning rate. If Noam decay is active, maximum learning rate.
    "wd": 0.000001,          // Weight decay weight.
    "warmup_steps": 4000,    // Noam decay steps to increase the learning rate from 0 to "lr"
    "seq_len_norm": false,   // Normalize each sample loss by its length to alleviate imbalanced datasets. Use it if your dataset is small or has a skewed distribution of sequence lengths.

    "encoder_type": "gatedconv",

    // TENSORBOARD and LOGGING
    "print_step": 25,       // Number of steps to log training on console.
    "tb_plot_step": 100,    // Number of steps to plot TB training figures.
    "print_eval": false,    // If true, it prints intermediate loss values in evaluation.
    "save_step": 5000,      // Number of training steps expected to save training stats and checkpoints.
    "checkpoint": true,     // If true, it saves checkpoints per "save_step"
    "tb_model_param_stats": false,   // If true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
    "apex_amp_level": null,

    // DATA LOADING
    "text_cleaner": "phoneme_cleaners",
    "enable_eos_bos_chars": false,   // enable/disable beginning-of-sentence and end-of-sentence chars.
    "num_loader_workers": 4,         // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 4,     // number of evaluation data loader processes.
    "batch_group_size": 0,           // Number of batches to shuffle after bucketing.
    "min_seq_len": 3,       // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 500,     // DATASET-RELATED: maximum text length
    "compute_f0": false,    // compute f0 values in the data loader

    // PATHS
    "output_path": "/home/erogol/Models/LJSpeech/",

    // PHONEMES
    "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/",   // phoneme computation is slow, therefore it caches results in the given folder.
    "use_phonemes": true,          // use phonemes instead of raw characters. It is suggested for better pronunciation.
    "phoneme_language": "en-us",   // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

    // MULTI-SPEAKER and GST
    "use_speaker_embedding": false,   // use speaker embedding to enable multi-speaker learning.
    "style_wav_for_test": null,       // path to the style wav file to be used in TacotronGST inference.
    "use_gst": false,                 // TACOTRON ONLY: use global style tokens

    // DATASETS
    "datasets":   // List of datasets. They are all merged and they get different speaker_ids.
        [
            {
                "name": "ljspeech",
                "path": "/home/erogol/Data/LJSpeech-1.1/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": null
                // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/"
            }
        ]
}

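A brief, hedged sketch of how a config file like the one above is typically consumed. The file name below is hypothetical; load_config and AudioProcessor are the helpers already imported by the training script and the test added in this commit.

from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

c = load_config("config_glow_tts_gatedconv.json")   # hypothetical file name
ap = AudioProcessor(**c.audio)                       # audio front-end built from the "audio" section
print(c.encoder_type, c.lr, c.audio["num_mels"])     # values are accessible as attributes or dict keys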
@@ -0,0 +1,132 @@
{
    "model": "glow_tts",
    "run_name": "glow-tts-tdsep-conv",
    "run_description": "glow-tts model training with time-depth separable conv encoder.",

    // AUDIO PARAMETERS
    "audio":{
        "fft_size": 1024,         // number of stft frequency levels. Size of the linear spectrogram frame.
        "win_length": 1024,       // stft window length in samples.
        "hop_length": 256,        // stft window hop length in samples.
        "frame_length_ms": null,  // stft window length in ms. If null, 'win_length' is used.
        "frame_shift_ms": null,   // stft window hop length in ms. If null, 'hop_length' is used.

        // Audio processing parameters
        "sample_rate": 22050,   // DATASET-RELATED: wav sample rate. If different than the original data, it is resampled.
        "preemphasis": 0.0,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "ref_level_db": 0,      // reference level db; theoretically 20 db is the sound of air.

        // Griffin-Lim
        "power": 1.1,              // value to sharpen wav signals after the GL algorithm.
        "griffin_lim_iters": 60,   // number of Griffin-Lim iterations. 30-60 is a good range. The larger the value, the slower the generation.

        // Silence trimming
        "do_trim_silence": true,   // enable trimming of silence from audio as it is loaded. LJSpeech (false), TWEB (false), Nancy (true)
        "trim_db": 60,             // threshold for trimming silence. Set this according to your dataset.

        // MelSpectrogram parameters
        "num_mels": 80,        // size of the mel spec frame.
        "mel_fmin": 50.0,      // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for your dataset!!
        "mel_fmax": 7600.0,    // maximum freq level for mel-spec. Tune for your dataset!!
        "spec_gain": 1.0,      // scalar value applied after the log transform of the spectrogram.

        // Normalization parameters
        "signal_norm": true,     // normalize spec values. Mean-var normalization if 'stats_path' is defined, otherwise range normalization defined by the other params.
        "min_level_db": -100,    // lower bound for normalization
        "symmetric_norm": true,  // move normalization to the range [-1, 1]
        "max_norm": 1.0,         // scale normalization to the range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,       // clip normalized values into the range.
        "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy"    // DO NOT USE WITH MULTI_SPEAKER MODEL. Scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and the other normalization params are ignored.
    },

    // VOCABULARY PARAMETERS
    // if a custom character set is not defined,
    // the default set in symbols.py is used
    // "characters":{
    //     "pad": "_",
    //     "eos": "~",
    //     "bos": "^",
    //     "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
    //     "punctuations":"!'(),-.:;? ",
    //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
    // },

    // DISTRIBUTED TRAINING
    "distributed":{
        "backend": "nccl",
        "url": "tcp://localhost:54321"
    },

    "reinit_layers": [],    // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

    // MODEL PARAMETERS
    "use_mas": false,       // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.

    // TRAINING
    "batch_size": 32,       // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.
    "eval_batch_size": 16,
    "r": 1,                 // Number of decoder frames to predict per iteration. Set the initial value if gradual training is enabled.
    "loss_masking": true,   // enable / disable loss masking against the sequence padding.

    // VALIDATION
    "run_eval": true,
    "test_delay_epochs": 0,        // Until attention is aligned, testing only wastes computation time.
    "test_sentences_file": null,   // set a file to load sentences to be used for testing. If it is null, the default English sentences are used.

    // OPTIMIZER
    "noam_schedule": true,   // use Noam warmup and lr schedule.
    "grad_clip": 5.0,        // upper limit for gradient clipping.
    "epochs": 10000,         // total number of epochs to train.
    "lr": 1e-3,              // Initial learning rate. If Noam decay is active, maximum learning rate.
    "wd": 0.000001,          // Weight decay weight.
    "warmup_steps": 4000,    // Noam decay steps to increase the learning rate from 0 to "lr"
    "seq_len_norm": false,   // Normalize each sample loss by its length to alleviate imbalanced datasets. Use it if your dataset is small or has a skewed distribution of sequence lengths.

    "encoder_type": "time-depth-separable",

    // TENSORBOARD and LOGGING
    "print_step": 25,       // Number of steps to log training on console.
    "tb_plot_step": 100,    // Number of steps to plot TB training figures.
    "print_eval": false,    // If true, it prints intermediate loss values in evaluation.
    "save_step": 5000,      // Number of training steps expected to save training stats and checkpoints.
    "checkpoint": true,     // If true, it saves checkpoints per "save_step"
    "tb_model_param_stats": false,   // If true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
    "apex_amp_level": null,

    // DATA LOADING
    "text_cleaner": "phoneme_cleaners",
    "enable_eos_bos_chars": false,   // enable/disable beginning-of-sentence and end-of-sentence chars.
    "num_loader_workers": 4,         // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 4,     // number of evaluation data loader processes.
    "batch_group_size": 0,           // Number of batches to shuffle after bucketing.
    "min_seq_len": 3,       // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 500,     // DATASET-RELATED: maximum text length
    "compute_f0": false,    // compute f0 values in the data loader

    // PATHS
    "output_path": "/home/erogol/Models/LJSpeech/",

    // PHONEMES
    "phoneme_cache_path": "/home/erogol/Models/phoneme_cache/",   // phoneme computation is slow, therefore it caches results in the given folder.
    "use_phonemes": true,          // use phonemes instead of raw characters. It is suggested for better pronunciation.
    "phoneme_language": "en-us",   // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

    // MULTI-SPEAKER and GST
    "use_speaker_embedding": false,   // use speaker embedding to enable multi-speaker learning.
    "style_wav_for_test": null,       // path to the style wav file to be used in TacotronGST inference.
    "use_gst": false,                 // TACOTRON ONLY: use global style tokens

    // DATASETS
    "datasets":   // List of datasets. They are all merged and they get different speaker_ids.
        [
            {
                "name": "ljspeech",
                "path": "/home/erogol/Data/LJSpeech-1.1/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": null
                // "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/"
            }
        ]
}

@@ -0,0 +1,44 @@
import torch
from torch import nn

from .normalization import LayerNorm


class GatedConvBlock(nn.Module):
    """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf

    Args:
        in_out_channels (int): number of input/output channels.
        kernel_size (int): convolution kernel size.
        dropout_p (float): dropout rate.
        num_layers (int): number of convolutional layers.
    """
    def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
        super().__init__()
        # class arguments
        self.dropout_p = dropout_p
        self.num_layers = num_layers
        # define layers
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.conv_layers += [
                nn.Conv1d(in_out_channels,
                          2 * in_out_channels,
                          kernel_size,
                          padding=kernel_size // 2)
            ]
            self.norm_layers += [LayerNorm(2 * in_out_channels)]

    def forward(self, x, x_mask):
        o = x
        res = x
        for idx in range(self.num_layers):
            o = nn.functional.dropout(o,
                                      p=self.dropout_p,
                                      training=self.training)
            o = self.conv_layers[idx](o * x_mask)
            o = self.norm_layers[idx](o)
            o = nn.functional.glu(o, dim=1)
            o = res + o
            res = o
        return o
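A hedged usage sketch for the block above; channel count, sequence length, and mask values are illustrative, and the class is assumed to be importable from this module.

import torch

block = GatedConvBlock(in_out_channels=192, kernel_size=5, dropout_p=0.1, num_layers=3)
x = torch.randn(2, 192, 50)      # (B, C, T) hidden sequence
x_mask = torch.ones(2, 1, 50)    # (B, 1, T) padding mask, 1 = valid frame
o = block(x, x_mask)             # residual gated-conv stack preserves the shape
assert o.shape == x.shape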
@@ -0,0 +1,101 @@
import torch
from torch import nn


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        """Layer norm for the 2nd dimension of the input.

        Args:
            channels (int): number of channels (2nd dimension) of the input.
            eps (float): small constant to prevent division by zero.

        Shapes:
            - input: (B, C, T)
            - output: (B, C, T)
        """
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1)
        self.beta = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x):
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean)**2, 1, keepdim=True)
        x = (x - mean) * torch.rsqrt(variance + self.eps)
        x = x * self.gamma + self.beta
        return x


class TemporalBatchNorm1d(nn.BatchNorm1d):
    """Normalize each channel separately over time and batch."""
    def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
        super(TemporalBatchNorm1d, self).__init__(channels,
                                                  affine=affine,
                                                  track_running_stats=track_running_stats,
                                                  momentum=momentum)

    def forward(self, x):
        return super().forward(x.transpose(2, 1)).transpose(2, 1)


class ActNorm(nn.Module):
    """Activation Normalization bijector as an alternative to Batch Norm. It computes
    mean and std from a sample batch in advance and uses these values
    for normalization during training.

    Args:
        channels (int): input channels.
        ddi (bool): data-dependent initialization flag.

    Shapes:
        - inputs: (B, C, T)
        - outputs: (B, C, T)
    """

    def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
        super().__init__()
        self.channels = channels
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
        if x_mask is None:
            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
                                                            dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (self.bias + torch.exp(self.logs) * x) * x_mask
            logdet = torch.sum(self.logs) * x_len  # [b]

        return z, logdet

    def store_inverse(self):
        pass

    def set_ddi(self, ddi):
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        with torch.no_grad():
            denom = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / denom
            m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            v = m_sq - (m**2)
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
                dtype=self.bias.dtype)
            logs_init = (-logs).view(*self.logs.shape).to(
                dtype=self.logs.dtype)

            self.bias.data.copy_(bias_init)
            self.logs.data.copy_(logs_init)
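Since ActNorm is a bijector, a quick hedged sanity check is to confirm that the reverse pass undoes the forward pass; all sizes below are illustrative.

import torch

actnorm = ActNorm(channels=80, ddi=False)
x = torch.randn(4, 80, 120)                   # (B, C, T)
x_mask = torch.ones(4, 1, 120)                # all frames valid
z, logdet = actnorm(x, x_mask)                # forward: z = bias + exp(logs) * x
x_rec, _ = actnorm(z, x_mask, reverse=True)   # inverse recovers the input
assert torch.allclose(x, x_rec, atol=1e-5)
print(logdet.shape)                           # per-sample log-determinant, shape (B,)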
@@ -0,0 +1,94 @@
import torch
from torch import nn

from .normalization import LayerNorm


class TimeDepthSeparableConv(nn.Module):
    """Time-depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf
    It shows competitive results with a smaller computation and memory footprint."""
    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 kernel_size,
                 bias=True):
        super(TimeDepthSeparableConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hid_channels = hid_channels
        self.kernel_size = kernel_size

        self.time_conv = nn.Conv1d(
            in_channels,
            2 * hid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.norm1 = nn.BatchNorm1d(2 * hid_channels)
        self.depth_conv = nn.Conv1d(
            hid_channels,
            hid_channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=hid_channels,
            bias=bias,
        )
        self.norm2 = nn.BatchNorm1d(hid_channels)
        self.time_conv2 = nn.Conv1d(
            hid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.norm3 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        x_res = x
        x = self.time_conv(x)
        x = self.norm1(x)
        x = nn.functional.glu(x, dim=1)
        x = self.depth_conv(x)
        x = self.norm2(x)
        x = x * torch.sigmoid(x)
        x = self.time_conv2(x)
        x = self.norm3(x)
        x = x_res + x
        return x


class TimeDepthSeparableConvBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 num_layers,
                 kernel_size,
                 bias=True):
        super(TimeDepthSeparableConvBlock, self).__init__()
        assert (kernel_size - 1) % 2 == 0
        assert num_layers > 1

        self.layers = nn.ModuleList()
        layer = TimeDepthSeparableConv(
            in_channels, hid_channels,
            out_channels if num_layers == 1 else hid_channels, kernel_size,
            bias)
        self.layers.append(layer)
        for idx in range(num_layers - 1):
            layer = TimeDepthSeparableConv(
                hid_channels, hid_channels, out_channels if
                (idx + 1) == (num_layers - 1) else hid_channels, kernel_size,
                bias)
            self.layers.append(layer)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x * mask)
        return x
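A hedged usage sketch for the block above. The residual connection inside each layer assumes matching channel counts, so input, hidden, and output channels are kept equal here; all sizes are illustrative.

import torch

block = TimeDepthSeparableConvBlock(in_channels=128,
                                    hid_channels=128,
                                    out_channels=128,
                                    num_layers=2,
                                    kernel_size=5)
x = torch.randn(2, 128, 60)    # (B, C, T) encoder hidden states
mask = torch.ones(2, 1, 60)    # (B, 1, T) padding mask
y = block(x, mask)             # shape is preserved: (2, 128, 60)
assert y.shape == x.shape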
@@ -112,10 +112,11 @@ class GlowTts(nn.Module):

    def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
        """
        x: B x T
        x_lengths: B
        y: B x D x T
        y_lengths: B
        Shapes:
            x: B x T
            x_lengths: B
            y: B x C x T
            y_lengths: B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings

@@ -0,0 +1,135 @@
import copy
import os
import unittest

import torch
from tests import get_tests_input_path
from torch import nn, optim

from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTts
from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor

#pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))

ap = AudioProcessor(**c.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class GlowTTSTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

        criterion = GlowTTSLoss()

        # model to train
        model = GlowTts(
            num_chars=32,
            hidden_channels=128,
            filter_channels=32,
            filter_channels_dp=32,
            out_channels=80,
            kernel_size=3,
            num_heads=2,
            num_layers_enc=6,
            dropout_p=0.1,
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=5,
            num_block_layers=4,
            dropout_p_dec=0.,
            num_speakers=0,
            c_in_channels=0,
            num_splits=4,
            num_sqz=1,
            sigmoid_scale=False,
            rel_attn_window_size=None,
            input_length=None,
            mean_only=False,
            hidden_channels_enc=None,
            hidden_channels_dec=None,
            use_encoder_prenet=False,
            encoder_type="transformer"
        ).to(device)

        # reference model to compare model weights
        model_ref = GlowTts(
            num_chars=32,
            hidden_channels=128,
            filter_channels=32,
            filter_channels_dp=32,
            out_channels=80,
            kernel_size=3,
            num_heads=2,
            num_layers_enc=6,
            dropout_p=0.1,
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=5,
            num_block_layers=4,
            dropout_p_dec=0.,
            num_speakers=0,
            c_in_channels=0,
            num_splits=4,
            num_sqz=1,
            sigmoid_scale=False,
            rel_attn_window_size=None,
            input_length=None,
            mean_only=False,
            hidden_channels_enc=None,
            hidden_channels_dec=None,
            use_encoder_prenet=False,
            encoder_type="transformer"
        ).to(device)

        model.train()
        print(" > Num parameters for GlowTTS model:%s" %
              (count_parameters(model)))

        # pass the state to ref model
        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))

        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1

        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for _ in range(5):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, None)
            optimizer.zero_grad()
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, input_lengths)
            loss = loss_dict['loss']
            loss.backward()
            optimizer.step()

        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param != param_ref).any(
            ), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref)
            count += 1