From 7029452228edc6731013e0afcdd5a2f289d978df Mon Sep 17 00:00:00 2001 From: mueller91 Date: Tue, 22 Sep 2020 10:31:42 +0200 Subject: [PATCH 1/8] fix: make speaker encoder's storage parameters non-restricted --- TTS/tts/utils/generic_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 6358e5a9..095ddc07 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -140,10 +140,10 @@ def check_config(c): check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) check_argument('trim_db', c['audio'], restricted=True, val_type=int) - # storage parameters - check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + # storage parameters (only for speaker encoder) + check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) + check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100) + check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) # training parameters check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) From 0ea7f4e2bd9a90ce9f5724915ef5c774d7a6a655 Mon Sep 17 00:00:00 2001 From: mueller91 Date: Tue, 22 Sep 2020 10:39:40 +0200 Subject: [PATCH 2/8] fix: make speaker encoder's storage parameters non-restricted --- TTS/tts/utils/generic_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 095ddc07..af32a769 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -141,9 +141,10 @@ def check_config(c): check_argument('trim_db', c['audio'], restricted=True, val_type=int) # storage parameters (only for speaker encoder) - check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) + if 'storage' in c.keys(): + check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) + check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100) + check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) # training parameters check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) From df4caec4b773377a1285e3133bd24be0ba0b11b9 Mon Sep 17 00:00:00 2001 From: mueller91 Date: Tue, 22 Sep 2020 19:52:09 +0200 Subject: [PATCH 3/8] add: check_config for speaker_encoder --- TTS/bin/train_encoder.py | 2 ++ TTS/bin/train_tts.py | 4 +-- TTS/speaker_encoder/utils.py | 61 ++++++++++++++++++++++++++++++++++ TTS/tts/utils/generic_utils.py | 8 +---- requirements.txt | 1 + 5 files changed, 67 insertions(+), 9 deletions(-) create mode 100644 TTS/speaker_encoder/utils.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 3222c278..57a17704 100644 --- a/TTS/bin/train_encoder.py +++
b/TTS/bin/train_encoder.py @@ -14,6 +14,7 @@ from TTS.speaker_encoder.dataset import MyDataset from TTS.speaker_encoder.generic_utils import save_best_model from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from TTS.speaker_encoder.model import SpeakerEncoder +from TTS.speaker_encoder.utils import check_config_speaker_encoder from TTS.speaker_encoder.visual import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.generic_utils import ( @@ -235,6 +236,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) + check_config_speaker_encoder(c) _ = os.path.dirname(os.path.realpath(__file__)) if args.data_path != '': c.data_path = args.data_path diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index f2641f9d..1b7351d4 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -17,7 +17,7 @@ from TTS.tts.layers.losses import TacotronLoss from TTS.tts.utils.distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) -from TTS.tts.utils.generic_utils import check_config, setup_model +from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, @@ -670,7 +670,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) - check_config(c) + check_config_tts(c) _ = os.path.dirname(os.path.realpath(__file__)) if c.apex_amp_level == 'O1': diff --git a/TTS/speaker_encoder/utils.py b/TTS/speaker_encoder/utils.py new file mode 100644 index 00000000..95c222f2 --- /dev/null +++ b/TTS/speaker_encoder/utils.py @@ -0,0 +1,61 @@ +from TTS.utils.generic_utils import check_argument + + +def check_config_speaker_encoder(c): + """Check the config.json file of the speaker encoder""" + check_argument('run_name', c, restricted=True, val_type=str) + check_argument('run_description', c, val_type=str) + + # audio processing parameters + check_argument('audio', c, restricted=True, val_type=dict) + check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') + check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') + check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + + # training parameters + check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str) + check_argument('grad_clip', c, restricted=True, val_type=float) + check_argument('epochs', c, restricted=True, val_type=int, min_val=1) + check_argument('lr', c, restricted=True, 
val_type=float, min_val=0) + check_argument('lr_decay', c, restricted=True, val_type=bool) + check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) + check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + check_argument('num_speakers_in_batch', c, restricted=True, val_type=int) + check_argument('num_loader_workers', c, restricted=True, val_type=int) + check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) + + # checkpoint and output parameters + check_argument('steps_plot_stats', c, restricted=True, val_type=int) + check_argument('checkpoint', c, restricted=True, val_type=bool) + check_argument('save_step', c, restricted=True, val_type=int) + check_argument('print_step', c, restricted=True, val_type=int) + check_argument('output_path', c, restricted=True, val_type=str) + + # model parameters + check_argument('model', c, restricted=True, val_type=dict) + check_argument('input_dim', c['model'], restricted=True, val_type=int) + check_argument('proj_dim', c['model'], restricted=True, val_type=int) + check_argument('lstm_dim', c['model'], restricted=True, val_type=int) + check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int) + check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool) + + # in-memory storage parameters + check_argument('storage', c, restricted=True, val_type=dict) + check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) + check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + + # datasets - checking only the first entry + check_argument('datasets', c, restricted=True, val_type=list) + for dataset_entry in c['datasets']: + check_argument('name', dataset_entry, restricted=True, val_type=str) + check_argument('path', dataset_entry, restricted=True, val_type=str) + check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) + check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index af32a769..f9d21644 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -100,7 +100,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): return model -def check_config(c): +def check_config_tts(c): check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str) @@ -140,12 +140,6 @@ def check_config(c): check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) check_argument('trim_db', c['audio'], restricted=True, val_type=int) - # storage parameters (only for speaker encoder) - if 'storage' in c.keys(): - check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) - # training parameters check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) diff --git a/requirements.txt 
b/requirements.txt index f0f2c057..85ffbe55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ nose==1.3.7 cardboardlint==1.3.0 pylint==2.5.3 gdown +umap-learn \ No newline at end of file From cfeeef7a7f8ad828e4e4ff34081e61103bee3d65 Mon Sep 17 00:00:00 2001 From: mueller91 Date: Tue, 22 Sep 2020 20:10:41 +0200 Subject: [PATCH 4/8] fix: broken imports and missing files after merging in the latest commits from mozilla/dev into mueller91/dev. speaker_encoder's config.json and visuals.py are missing in the current dev branch of MozillaTTS, and some imports are broken. --- TTS/bin/train_encoder.py | 5 +- TTS/speaker_encoder/config.json | 103 ++++++++++++++++++++++++++++++++ TTS/speaker_encoder/visuals.py | 46 ++++++++++++++ 3 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 TTS/speaker_encoder/config.json create mode 100644 TTS/speaker_encoder/visuals.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index cb5e67b9..95030309 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -11,13 +11,12 @@ import torch from torch.utils.data import DataLoader from TTS.speaker_encoder.dataset import MyDataset -from TTS.speaker_encoder.utils.generic_utils import save_best_model from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from TTS.speaker_encoder.model import SpeakerEncoder -from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.speaker_encoder.utils import check_config_speaker_encoder -from TTS.speaker_encoder.visual import plot_embeddings +from TTS.speaker_encoder.visuals import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data +from TTS.tts.utils.io import save_best_model from TTS.utils.generic_utils import ( create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) diff --git a/TTS/speaker_encoder/config.json b/TTS/speaker_encoder/config.json new file mode 100644 index 00000000..4fbd84cc --- /dev/null +++ b/TTS/speaker_encoder/config.json @@ -0,0 +1,103 @@ + +{ + "run_name": "mueller91", + "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", + "audio":{ + // Audio processing parameters + "num_mels": 40, // size of the mel spec frame. + "fft_size": 400, // number of stft frequency levels. Size of the linear spectrogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "win_length": 400, // stft window length in samples. + "hop_length": 160, // stft window hop length in samples. + "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop length in ms. If null, 'hop_length' is used. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, pre-emphasis is disabled. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60, // #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation. + // Normalization parameters + "signal_norm": true, // normalize the spec values in range [0, 1] + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec.
~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": true, // enable trimming of silence from audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for trimming silence. Set this according to your dataset. + }, + "reinit_layers": [], + "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) + "grad_clip": 3.0, // upper limit for gradients for clipping. + "epochs": 1000, // total number of epochs to train. + "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_decay": false, // if true, Noam learning rate decaying is applied through training. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "tb_model_param_stats": false, // if true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "steps_plot_stats": 10, // number of steps to plot embeddings. + "num_speakers_in_batch": 64, // number of speakers per training batch. + "num_utters_per_speaker": 10, // number of utterances sampled per speaker in a batch. + "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "wd": 0.000001, // Weight decay coefficient. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps between saving training stats and checkpoints. + "print_step": 20, // Number of steps between logging training stats to the console. + "output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+ "model": { + "input_dim": 40, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": true + }, + "storage": { + "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 15, // the size of the in-memory storage with respect to a single batch + "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness + }, + "datasets": + [ + { + "name": "vctk_slim", + "path": "../../../audio-datasets/en/VCTK-Corpus/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-other-500", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb1", + "path": "../../../audio-datasets/en/voxceleb1/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb2", + "path": "../../../audio-datasets/en/voxceleb2/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "common_voice", + "path": "../../../audio-datasets/en/MozillaCommonVoice", + "meta_file_train": "train.tsv", + "meta_file_val": "test.tsv" + } + ] +} \ No newline at end of file diff --git a/TTS/speaker_encoder/visuals.py b/TTS/speaker_encoder/visuals.py new file mode 100644 index 00000000..68c48f12 --- /dev/null +++ b/TTS/speaker_encoder/visuals.py @@ -0,0 +1,46 @@ +import umap +import numpy as np +import matplotlib +import matplotlib.pyplot as plt + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=np.float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_utter_per_speaker): + embeddings = embeddings[: 10 * num_utter_per_speaker] + model = umap.UMAP() + projection = model.fit_transform(embeddings) + num_speakers = embeddings.shape[0] // num_utter_per_speaker + ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker) + colors = [colormap[i] for i in ground_truth] + + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig From 227b9c886429f3f8f8074f5377646cd5bbc58ba8 Mon Sep 17 00:00:00 2001 From: mueller91 Date: Wed, 23 Sep 2020 23:27:51 +0200 Subject: [PATCH 5/8] fix: split_dataset() runtime reduced from O(N * |items|) to O(N) where N is the size of the eval split (max 500) I notice a significant speedup on the initial loading of large datasets such as common voice (from minutes to seconds) --- TTS/tts/utils/generic_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index a0aba29a..2b165951 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -16,13 +16,14 @@ def split_dataset(items): np.random.shuffle(items) if is_multi_speaker: items_eval = [] - # most stupid code ever -- Fix it ! 
+ speakers = [item[-1] for item in items] + speaker_counter = Counter(speakers) while len(items_eval) < eval_split_size: - speakers = [item[-1] for item in items] - speaker_counter = Counter(speakers) item_idx = np.random.randint(0, len(items)) - if speaker_counter[items[item_idx][-1]] > 1: + speaker_to_be_removed = items[item_idx][-1] + if speaker_counter[speaker_to_be_removed] > 1: items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 del items[item_idx] return items_eval, items return items[:eval_split_size], items[eval_split_size:] @@ -127,6 +128,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): return model + def check_config_tts(c): check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) From 665f7ca714c3d1f32449b48d0149f54482989aed Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 24 Sep 2020 12:57:54 +0200 Subject: [PATCH 6/8] linter fix --- TTS/tts/layers/glow_tts/glow.py | 66 +------------------ TTS/tts/layers/glow_tts/normalization.py | 9 ++- .../layers/glow_tts/time_depth_sep_conv.py | 4 +- TTS/tts/layers/losses.py | 66 +++++++++++-------- .../models/fullband_melgan_generator.py | 15 ++--- 5 files changed, 53 insertions(+), 107 deletions(-) diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index acfc55e5..b06dd8a5 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -69,7 +69,7 @@ class WN(torch.nn.Module): num_layers, c_in_channels=0, dropout_p=0): - super(WN, self).__init__() + super().__init__() assert kernel_size % 2 == 1 assert hidden_channels % 2 == 0 self.in_channels = in_channels @@ -148,70 +148,6 @@ class WN(torch.nn.Module): for l in self.res_skip_layers: torch.nn.utils.remove_weight_norm(l) - -class ActNorm(nn.Module): - """Activation Normalization bijector as an alternative to Batch Norm. It computes - mean and std from a sample data in advance and it uses these values - for normalization at training. - - Args: - channels (int): input channels. - ddi (False): data depended initialization flag. 
- - Shapes: - - inputs: (B, C, T) - - outputs: (B, C, T) - """ - - def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument - super().__init__() - self.channels = channels - self.initialized = not ddi - - self.logs = nn.Parameter(torch.zeros(1, channels, 1)) - self.bias = nn.Parameter(torch.zeros(1, channels, 1)) - - def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument - if x_mask is None: - x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, - dtype=x.dtype) - x_len = torch.sum(x_mask, [1, 2]) - if not self.initialized: - self.initialize(x, x_mask) - self.initialized = True - - if reverse: - z = (x - self.bias) * torch.exp(-self.logs) * x_mask - logdet = None - else: - z = (self.bias + torch.exp(self.logs) * x) * x_mask - logdet = torch.sum(self.logs) * x_len # [b] - - return z, logdet - - def store_inverse(self): - pass - - def set_ddi(self, ddi): - self.initialized = not ddi - - def initialize(self, x, x_mask): - with torch.no_grad(): - denom = torch.sum(x_mask, [0, 2]) - m = torch.sum(x * x_mask, [0, 2]) / denom - m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m**2) - logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) - - bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to( - dtype=self.bias.dtype) - logs_init = (-logs).view(*self.logs.shape).to( - dtype=self.logs.dtype) - - self.bias.data.copy_(bias_init) - self.logs.data.copy_(logs_init) - - class InvConvNear(nn.Module): def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument super().__init__() diff --git a/TTS/tts/layers/glow_tts/normalization.py b/TTS/tts/layers/glow_tts/normalization.py index 0930f48c..5ccdeb47 100644 --- a/TTS/tts/layers/glow_tts/normalization.py +++ b/TTS/tts/layers/glow_tts/normalization.py @@ -36,11 +36,10 @@ class TemporalBatchNorm1d(nn.BatchNorm1d): affine=True, track_running_stats=True, momentum=0.1): - super(TemporalBatchNorm1d, - self).__init__(channels, - affine=affine, - track_running_stats=track_running_stats, - momentum=momentum) + super().__init__(channels, + affine=affine, + track_running_stats=track_running_stats, + momentum=momentum) def forward(self, x): return super().forward(x.transpose(2, 1)).transpose(2, 1) diff --git a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py index 732e7d96..c9a117c8 100644 --- a/TTS/tts/layers/glow_tts/time_depth_sep_conv.py +++ b/TTS/tts/layers/glow_tts/time_depth_sep_conv.py @@ -11,7 +11,7 @@ class TimeDepthSeparableConv(nn.Module): out_channels, kernel_size, bias=True): - super(TimeDepthSeparableConv, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -69,7 +69,7 @@ class TimeDepthSeparableConvBlock(nn.Module): num_layers, kernel_size, bias=True): - super(TimeDepthSeparableConvBlock, self).__init__() + super().__init__() assert (kernel_size - 1) % 2 == 0 assert num_layers > 1 diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 4bc31d90..bf03671c 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -7,9 +7,8 @@ from TTS.tts.utils.generic_utils import sequence_mask class L1LossMasked(nn.Module): - def __init__(self, seq_len_norm): - super(L1LossMasked, self).__init__() + super().__init__() self.seq_len_norm = seq_len_norm def forward(self, x, target, length): @@ -28,25 +27,24 @@ class L1LossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - 
mask = sequence_mask( - sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, + max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.l1_loss( - x * mask, target * mask, reduction='none') + loss = functional.l1_loss(x * mask, + target * mask, + reduction='none') loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.l1_loss( - x * mask, target * mask, reduction='sum') + loss = functional.l1_loss(x * mask, target * mask, reduction='sum') loss = loss / mask.sum() return loss class MSELossMasked(nn.Module): - def __init__(self, seq_len_norm): super(MSELossMasked, self).__init__() self.seq_len_norm = seq_len_norm @@ -67,19 +65,21 @@ class MSELossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask( - sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, + max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.mse_loss( - x * mask, target * mask, reduction='none') + loss = functional.mse_loss(x * mask, + target * mask, + reduction='none') loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.mse_loss( - x * mask, target * mask, reduction='sum') + loss = functional.mse_loss(x * mask, + target * mask, + reduction='sum') loss = loss / mask.sum() return loss @@ -100,7 +100,6 @@ class AttentionEntropyLoss(nn.Module): class BCELossMasked(nn.Module): - def __init__(self, pos_weight): super(BCELossMasked, self).__init__() self.pos_weight = pos_weight @@ -121,9 +120,13 @@ class BCELossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() + mask = sequence_mask(sequence_length=length, + max_len=target.size(1)).float() loss = functional.binary_cross_entropy_with_logits( - x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum') + x * mask, + target * mask, + pos_weight=self.pos_weight, + reduction='sum') loss = loss / mask.sum() return loss @@ -139,7 +142,8 @@ class GuidedAttentionLoss(torch.nn.Module): max_olen = max(olens) ga_masks = torch.zeros((B, max_olen, max_ilen)) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): - ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) + ga_masks[idx, :olen, :ilen] = self._make_ga_mask( + ilen, olen, self.sigma) return ga_masks def forward(self, att_ws, ilens, olens): @@ -153,7 +157,8 @@ class GuidedAttentionLoss(torch.nn.Module): def _make_ga_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() - return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2))) + return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 / + (2 * (sigma**2))) @staticmethod def _make_masks(ilens, olens): @@ -181,7 +186,8 @@ class TacotronLoss(torch.nn.Module): self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) # stopnet loss # pylint: disable=not-callable - self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if 
c.stopnet else None + self.criterion_st = BCELossMasked( + pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None def forward(self, postnet_output, decoder_output, mel_input, linear_input, stopnet_output, stopnet_target, output_lens, decoder_b_output, @@ -219,19 +225,25 @@ class TacotronLoss(torch.nn.Module): # backward decoder loss (if enabled) if self.config.bidirectional_decoder: if self.config.loss_masking: - decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input, output_lens) + decoder_b_loss = self.criterion( + torch.flip(decoder_b_output, dims=(1, )), mel_input, + output_lens) else: - decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input) - decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output) + decoder_b_loss = self.criterion( + torch.flip(decoder_b_output, dims=(1, )), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss( + torch.flip(decoder_b_output, dims=(1, )), decoder_output) loss += decoder_b_loss + decoder_c_loss return_dict['decoder_b_loss'] = decoder_b_loss return_dict['decoder_c_loss'] = decoder_c_loss # double decoder consistency loss (if enabled) if self.config.double_decoder_consistency: - decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) + decoder_b_loss = self.criterion(decoder_b_output, mel_input, + output_lens) # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) - attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + attention_c_loss = torch.nn.functional.l1_loss( + alignments, alignments_backwards) loss += decoder_b_loss + attention_c_loss return_dict['decoder_coarse_loss'] = decoder_b_loss return_dict['decoder_ddc_loss'] = attention_c_loss @@ -248,7 +260,7 @@ class TacotronLoss(torch.nn.Module): class GlowTTSLoss(torch.nn.Module): def __init__(self): - super(GlowTTSLoss, self).__init__() + super().__init__() self.constant_factor = 0.5 * math.log(2 * math.pi) def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py index 9f90ee17..52dcc75e 100644 --- a/TTS/vocoder/models/fullband_melgan_generator.py +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -12,14 +12,13 @@ class FullbandMelganGenerator(MelganGenerator): upsample_factors=(2, 8, 2, 2), res_kernel=3, num_res_blocks=4): - super(FullbandMelganGenerator, - self).__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + super().__init__(in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks) @torch.no_grad() def inference(self, cond_features): From 6a70c63f242c49f2945d1f97b71efb0cdfb138c2 Mon Sep 17 00:00:00 2001 From: erogol Date: Sun, 27 Sep 2020 03:28:42 +0200 Subject: [PATCH 7/8] correct glow-tts loss --- TTS/tts/layers/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index bf03671c..15b0a2ee 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -270,7 +270,7 @@ class GlowTTSLoss(torch.nn.Module): pz = torch.sum(scales) + 0.5 * torch.sum( torch.exp(-2 * scales) * (z - 
means)**2) log_mle = self.constant_factor + (pz - torch.sum(log_det)) / ( - torch.sum(y_lengths // 2) * 2 * z.shape[1]) + torch.sum(y_lengths) * z.shape[1]) # duration loss - MSE # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) # duration loss - huber loss From 154f90bc44f4a3287e599971239075ad3d1dfea8 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 28 Sep 2020 11:19:19 +0200 Subject: [PATCH 8/8] format speaker encoder imports --- TTS/bin/train_encoder.py | 17 ++-- TTS/speaker_encoder/compute_embeddings.py | 88 ------------------- TTS/speaker_encoder/utils.py | 61 ------------- TTS/speaker_encoder/visuals.py | 46 ---------- .../configs/multiband_melgan_config.json | 5 +- 5 files changed, 9 insertions(+), 208 deletions(-) delete mode 100644 TTS/speaker_encoder/compute_embeddings.py delete mode 100644 TTS/speaker_encoder/utils.py delete mode 100644 TTS/speaker_encoder/visuals.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 95030309..8d1f14fa 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -9,20 +9,19 @@ import traceback import torch from torch.utils.data import DataLoader - from TTS.speaker_encoder.dataset import MyDataset -from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss +from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss from TTS.speaker_encoder.model import SpeakerEncoder -from TTS.speaker_encoder.utils import check_config_speaker_encoder -from TTS.speaker_encoder.visuals import plot_embeddings +from TTS.speaker_encoder.utils.generic_utils import \ + check_config_speaker_encoder +from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.utils.io import save_best_model -from TTS.utils.generic_utils import ( - create_experiment_folder, get_git_branch, remove_experiment_folder, - set_init_dict) -from TTS.utils.io import copy_config_file, load_config from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import count_parameters +from TTS.utils.generic_utils import (count_parameters, + create_experiment_folder, get_git_branch, + remove_experiment_folder, set_init_dict) +from TTS.utils.io import copy_config_file, load_config from TTS.utils.radam import RAdam from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import NoamLR, check_update diff --git a/TTS/speaker_encoder/compute_embeddings.py b/TTS/speaker_encoder/compute_embeddings.py deleted file mode 100644 index c8608755..00000000 --- a/TTS/speaker_encoder/compute_embeddings.py +++ /dev/null @@ -1,88 +0,0 @@ -import argparse -import glob -import os - -import numpy as np -from tqdm import tqdm - -import torch -from TTS.speaker_encoder.model import SpeakerEncoder -from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config - -parser = argparse.ArgumentParser( - description='Compute embedding vectors for each wav file in a dataset. 
') -parser.add_argument( - 'model_path', - type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).') -parser.add_argument( - 'config_path', - type=str, - help='Path to config file for training.', -) -parser.add_argument( - 'data_path', - type=str, - help='Data path for wav files - directory or CSV file') -parser.add_argument( - 'output_path', - type=str, - help='path for training outputs.') -parser.add_argument( - '--use_cuda', type=bool, help='flag to set cuda.', default=False -) -parser.add_argument( - '--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|' -) -args = parser.parse_args() - - -c = load_config(args.config_path) -ap = AudioProcessor(**c['audio']) - -data_path = args.data_path -split_ext = os.path.splitext(data_path) -sep = args.separator - -if len(split_ext) > 0 and split_ext[1].lower() == '.csv': - # Parse CSV - print(f'CSV file: {data_path}') - with open(data_path) as f: - wav_path = os.path.join(os.path.dirname(data_path), 'wavs') - wav_files = [] - print(f'Separator is: {sep}') - for line in f: - components = line.split(sep) - if len(components) != 2: - print("Invalid line") - continue - wav_file = os.path.join(wav_path, components[0] + '.wav') - #print(f'wav_file: {wav_file}') - if os.path.exists(wav_file): - wav_files.append(wav_file) - print(f'Count of wavs imported: {len(wav_files)}') -else: - # Parse all wav files in data_path - wav_path = data_path - wav_files = glob.glob(data_path + '/**/*.wav', recursive=True) - -output_files = [wav_file.replace(wav_path, args.output_path).replace( - '.wav', '.npy') for wav_file in wav_files] - -for output_file in output_files: - os.makedirs(os.path.dirname(output_file), exist_ok=True) - -model = SpeakerEncoder(**c.model) -model.load_state_dict(torch.load(args.model_path)['model']) -model.eval() -if args.use_cuda: - model.cuda() - -for idx, wav_file in enumerate(tqdm(wav_files)): - mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T - mel_spec = torch.FloatTensor(mel_spec[None, :, :]) - if args.use_cuda: - mel_spec = mel_spec.cuda() - embedd = model.compute_embedding(mel_spec) - np.save(output_files[idx], embedd.detach().cpu().numpy()) diff --git a/TTS/speaker_encoder/utils.py b/TTS/speaker_encoder/utils.py deleted file mode 100644 index 95c222f2..00000000 --- a/TTS/speaker_encoder/utils.py +++ /dev/null @@ -1,61 +0,0 @@ -from TTS.utils.generic_utils import check_argument - - -def check_config_speaker_encoder(c): - """Check the config.json file of the speaker encoder""" - check_argument('run_name', c, restricted=True, val_type=str) - check_argument('run_description', c, val_type=str) - - # audio processing parameters - check_argument('audio', c, restricted=True, val_type=dict) - check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - 
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) - - # training parameters - check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str) - check_argument('grad_clip', c, restricted=True, val_type=float) - check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('lr_decay', c, restricted=True, val_type=bool) - check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) - check_argument('num_speakers_in_batch', c, restricted=True, val_type=int) - check_argument('num_loader_workers', c, restricted=True, val_type=int) - check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) - - # checkpoint and output parameters - check_argument('steps_plot_stats', c, restricted=True, val_type=int) - check_argument('checkpoint', c, restricted=True, val_type=bool) - check_argument('save_step', c, restricted=True, val_type=int) - check_argument('print_step', c, restricted=True, val_type=int) - check_argument('output_path', c, restricted=True, val_type=str) - - # model parameters - check_argument('model', c, restricted=True, val_type=dict) - check_argument('input_dim', c['model'], restricted=True, val_type=int) - check_argument('proj_dim', c['model'], restricted=True, val_type=int) - check_argument('lstm_dim', c['model'], restricted=True, val_type=int) - check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int) - check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool) - - # in-memory storage parameters - check_argument('storage', c, restricted=True, val_type=dict) - check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) - - # datasets - checking only the first entry - check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - check_argument('name', dataset_entry, restricted=True, val_type=str) - check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/TTS/speaker_encoder/visuals.py b/TTS/speaker_encoder/visuals.py deleted file mode 100644 index 68c48f12..00000000 --- a/TTS/speaker_encoder/visuals.py +++ /dev/null @@ -1,46 +0,0 @@ -import umap -import numpy as np -import matplotlib -import matplotlib.pyplot as plt - -matplotlib.use("Agg") - - -colormap = ( - np.array( - [ - [76, 255, 0], - [0, 127, 70], - [255, 0, 0], - [255, 217, 38], - [0, 135, 255], - [165, 0, 165], - [255, 167, 255], - [0, 255, 255], - [255, 96, 38], - [142, 76, 0], - [33, 0, 127], - [0, 0, 0], - [183, 183, 183], - ], - dtype=np.float, - ) - / 255 -) - - -def plot_embeddings(embeddings, num_utter_per_speaker): - embeddings = embeddings[: 10 * num_utter_per_speaker] - model = umap.UMAP() - projection = 
model.fit_transform(embeddings) - num_speakers = embeddings.shape[0] // num_utter_per_speaker - ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker) - colors = [colormap[i] for i in ground_truth] - - fig, ax = plt.subplots(figsize=(16, 10)) - _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) - plt.gca().set_aspect("equal", "datalim") - plt.title("UMAP projection") - plt.tight_layout() - plt.savefig("umap") - return fig diff --git a/TTS/vocoder/configs/multiband_melgan_config.json b/TTS/vocoder/configs/multiband_melgan_config.json index a89d43bb..7a5a13e3 100644 --- a/TTS/vocoder/configs/multiband_melgan_config.json +++ b/TTS/vocoder/configs/multiband_melgan_config.json @@ -40,12 +40,9 @@ // "url": "tcp:\/\/localhost:54321" // }, - // MODEL PARAMETERS - "use_pqmf": true, - // LOSS PARAMETERS "use_stft_loss": true, - "use_subband_stft_loss": true, + "use_subband_stft_loss": true, // use only with multi-band models. "use_mse_gan_loss": true, "use_hinge_gan_loss": false, "use_feat_match_loss": false, // use only with melgan discriminators
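
Note on PATCH 5/8: the patched split_dataset() gets its speedup by building the per-speaker Counter once, outside the sampling loop, and decrementing it as items move to the eval split, so each draw costs O(1) instead of a full rescan of the item list. The sketch below reproduces the patched logic as a standalone function; the pieces outside the visible hunk (the eval_split_size formula and the is_multi_speaker flag) are reconstructed assumptions, not taken from the patch.

import numpy as np
from collections import Counter


def split_dataset(items):
    # each item is assumed to be [text, wav_path, speaker_name]
    # eval split capped at 500 items, per the commit message
    eval_split_size = min(500, len(items) // 10)
    is_multi_speaker = len({item[-1] for item in items}) > 1
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        # count utterances per speaker once; update incrementally afterwards
        speaker_counter = Counter(item[-1] for item in items)
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker = items[item_idx][-1]
            # keep at least one utterance per speaker in the training split
            if speaker_counter[speaker] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]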