From a2606fbc221989cf2f5e0ccb4dce470c31795df2 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 6 Oct 2020 11:02:54 +0200
Subject: [PATCH] format utils

---
 TTS/speaker_encoder/utils/__init__.py         |   0
 TTS/speaker_encoder/utils/generic_utils.py    | 118 ++++++
 TTS/speaker_encoder/utils/io.py               |   0
 TTS/speaker_encoder/utils/prepare_voxceleb.py | 233 +++++++++++
 TTS/speaker_encoder/utils/visual.py           |  46 +++
 utils/generic_utils.py                        | 374 ------------------
 6 files changed, 397 insertions(+), 374 deletions(-)
 create mode 100644 TTS/speaker_encoder/utils/__init__.py
 create mode 100644 TTS/speaker_encoder/utils/generic_utils.py
 create mode 100644 TTS/speaker_encoder/utils/io.py
 create mode 100644 TTS/speaker_encoder/utils/prepare_voxceleb.py
 create mode 100644 TTS/speaker_encoder/utils/visual.py
 delete mode 100644 utils/generic_utils.py

diff --git a/TTS/speaker_encoder/utils/__init__.py b/TTS/speaker_encoder/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py
new file mode 100644
index 00000000..d2cc0ede
--- /dev/null
+++ b/TTS/speaker_encoder/utils/generic_utils.py
@@ -0,0 +1,118 @@
+import datetime
+import importlib
+import os
+import re
+
+import torch
+from TTS.speaker_encoder.model import SpeakerEncoder
+from TTS.utils.generic_utils import check_argument
+
+
+def to_camel(text):
+    text = text.capitalize()
+    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+
+
+def setup_model(c):
+    model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
+                           c.model['lstm_dim'], c.model['num_lstm_layers'])
+    return model
+
+
+def save_checkpoint(model, optimizer, model_loss, out_path,
+                    current_step, epoch):
+    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
+    checkpoint_path = os.path.join(out_path, checkpoint_path)
+    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
+
+    new_state_dict = model.state_dict()
+    state = {
+        'model': new_state_dict,
+        'optimizer': optimizer.state_dict() if optimizer is not None else None,
+        'step': current_step,
+        'epoch': epoch,
+        'loss': model_loss,
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+    }
+    torch.save(state, checkpoint_path)
+
+
+def save_best_model(model, optimizer, model_loss, best_loss, out_path,
+                    current_step):
+    if model_loss < best_loss:
+        new_state_dict = model.state_dict()
+        state = {
+            'model': new_state_dict,
+            'optimizer': optimizer.state_dict(),
+            'step': current_step,
+            'loss': model_loss,
+            'date': datetime.date.today().strftime("%B %d, %Y"),
+        }
+        best_loss = model_loss
+        bestmodel_path = 'best_model.pth.tar'
+        bestmodel_path = os.path.join(out_path, bestmodel_path)
+        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
+            model_loss, bestmodel_path))
+        torch.save(state, bestmodel_path)
+    return best_loss
+
+
+def check_config_speaker_encoder(c):
+    """Check the config.json file of the speaker encoder"""
+    check_argument('run_name', c, restricted=True, val_type=str)
+    check_argument('run_description', c, val_type=str)
+
+    # audio processing parameters
+    check_argument('audio', c, restricted=True, val_type=dict)
+    check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
+    check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
+    check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
+    check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
+    check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
+    check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
+    check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
+    check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
+    check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
+    check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
+
+    # training parameters
+    check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
+    check_argument('grad_clip', c, restricted=True, val_type=float)
+    check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
+    check_argument('lr', c, restricted=True, val_type=float, min_val=0)
+    check_argument('lr_decay', c, restricted=True, val_type=bool)
+    check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
+    check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
+    check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
+    check_argument('num_loader_workers', c, restricted=True, val_type=int)
+    check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
+
+    # checkpoint and output parameters
+    check_argument('steps_plot_stats', c, restricted=True, val_type=int)
+    check_argument('checkpoint', c, restricted=True, val_type=bool)
+    check_argument('save_step', c, restricted=True, val_type=int)
+    check_argument('print_step', c, restricted=True, val_type=int)
+    check_argument('output_path', c, restricted=True, val_type=str)
+
+    # model parameters
+    check_argument('model', c, restricted=True, val_type=dict)
+    check_argument('input_dim', c['model'], restricted=True, val_type=int)
+    check_argument('proj_dim', c['model'], restricted=True, val_type=int)
+    check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
+    check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
+    check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
+
+    # in-memory storage parameters
+    check_argument('storage', c, restricted=True, val_type=dict)
+    check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
+    check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
+    check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
+
+    # datasets - checking only the first entry
+    check_argument('datasets', c, restricted=True, val_type=list)
+    for dataset_entry in c['datasets']:
+        check_argument('name', dataset_entry, restricted=True, val_type=str)
+        check_argument('path', dataset_entry, restricted=True, val_type=str)
+        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=(str, list))
+        check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
+
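The helpers above are meant to be driven by a training script. Below is a minimal sketch of that flow; the config path is hypothetical, and load_config from TTS.utils.io plus attribute-style config access are assumptions based on the surrounding codebase, not part of this patch:

    import os

    import torch

    from TTS.speaker_encoder.utils.generic_utils import (
        check_config_speaker_encoder, save_best_model, save_checkpoint, setup_model)
    from TTS.utils.io import load_config  # assumed helper, not introduced here

    c = load_config("config.json")   # hypothetical speaker encoder config
    check_config_speaker_encoder(c)  # raises AssertionError on missing or ill-typed keys
    model = setup_model(c)           # builds SpeakerEncoder from c.model[...]
    optimizer = torch.optim.Adam(model.parameters(), lr=c.lr)

    out_path = "output"
    os.makedirs(out_path, exist_ok=True)
    best_loss = float("inf")
    for step in range(1, 4):
        loss = 1.0 / step  # stand-in for the real training loss
        save_checkpoint(model, optimizer, loss, out_path, step, epoch=0)
        best_loss = save_best_model(model, optimizer, loss, best_loss, out_path, step)

Note that save_best_model only writes best_model.pth.tar when the loss improves, and it returns the running best loss, so the caller is responsible for threading that value through the loop.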
diff --git a/TTS/speaker_encoder/utils/io.py b/TTS/speaker_encoder/utils/io.py
new file mode 100644
index 00000000..e69de29b
diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/speaker_encoder/utils/prepare_voxceleb.py
new file mode 100644
index 00000000..758e1cb3
--- /dev/null
+++ b/TTS/speaker_encoder/utils/prepare_voxceleb.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Only supports eager mode and TF>=2.0.0
+# pylint: disable=no-member, invalid-name, relative-beyond-top-level
+# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
+''' voxceleb 1 & 2 '''
+
+import os
+import sys
+import zipfile
+import subprocess
+import hashlib
+import pandas
+from absl import logging
+import tensorflow as tf
+import soundfile as sf
+
+gfile = tf.compat.v1.gfile
+
+SUBSETS = {
+    "vox1_dev_wav": [
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
+    "vox1_test_wav": [
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
+    "vox2_dev_aac": [
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
+    "vox2_test_aac": [
+        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
+}
+
+MD5SUM = {
+    "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
+    "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
+    "vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
+    "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
+}
+
+USER = {
+    "user": "",
+    "password": ""
+}
+
+speaker_id_dict = {}
+
+
+def download_and_extract(directory, subset, urls):
+    """Download and extract the given split of the dataset.
+
+    Args:
+        directory: the directory where to put the downloaded data.
+        subset: subset name of the corpus.
+        urls: the list of urls to download the data file.
+    """
+    if not gfile.Exists(directory):
+        gfile.MakeDirs(directory)
+
+    try:
+        for url in urls:
+            zip_filepath = os.path.join(directory, url.split("/")[-1])
+            if os.path.exists(zip_filepath):
+                continue
+            logging.info("Downloading %s to %s" % (url, zip_filepath))
+            subprocess.call('wget %s --user %s --password %s -O %s' %
+                            (url, USER["user"], USER["password"], zip_filepath), shell=True)
+
+            statinfo = os.stat(zip_filepath)
+            logging.info(
+                "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
+            )
+
+        # concatenate all parts into zip files
+        if ".zip" not in zip_filepath:
+            zip_filepath = "_".join(zip_filepath.split("_")[:-1])
+            subprocess.call('cat %s* > %s.zip' %
+                            (zip_filepath, zip_filepath), shell=True)
+            zip_filepath += ".zip"
+        extract_path = zip_filepath[:-len(".zip")]  # str.strip() removes a character set, not a suffix
+
+        # check zip file md5sum
+        md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest()
+        if md5 != MD5SUM[subset]:
+            raise ValueError("md5sum of %s mismatch" % zip_filepath)
+
+        with zipfile.ZipFile(zip_filepath, "r") as zfile:
+            zfile.extractall(directory)
+            extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
+        subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
+    finally:
+        # gfile.Remove(zip_filepath)
+        pass
+
+
+def exec_cmd(cmd):
+    """Run a command in a subprocess.
+    Args:
+        cmd: command line to be executed.
+    Return:
+        int, the return code.
+    """
+    try:
+        retcode = subprocess.call(cmd, shell=True)
+        if retcode < 0:
+            logging.info(f"Child was terminated by signal {retcode}")
+    except OSError as e:
+        logging.info(f"Execution failed: {e}")
+        retcode = -999
+    return retcode
+
+
+def decode_aac_with_ffmpeg(aac_file, wav_file):
+    """Decode a given AAC file into WAV using ffmpeg.
+    Args:
+        aac_file: file path to input AAC file.
+        wav_file: file path to output WAV file.
+    Return:
+        bool, True if success.
+    """
+    cmd = f"ffmpeg -i {aac_file} {wav_file}"
+    logging.info(f"Decoding aac file using command line: {cmd}")
+    ret = exec_cmd(cmd)
+    if ret != 0:
+        logging.error(f"Failed to decode aac file with retcode {ret}")
+        logging.error("Please check your ffmpeg installation.")
+        return False
+    return True
+
+
+def convert_audio_and_make_label(input_dir, subset,
+                                 output_dir, output_file):
+    """Optionally convert AAC to WAV and make speaker labels.
+    Args:
+        input_dir: the directory which holds the input dataset.
+        subset: the name of the specified subset. e.g. vox1_dev_wav
+        output_dir: the directory to place the newly generated csv files.
+        output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
+    """
+
+    logging.info("Preprocessing audio and label for subset %s" % subset)
+    source_dir = os.path.join(input_dir, subset)
+
+    files = []
+    # Convert all AAC files into WAV format and generate the csv at the same time.
+    for root, _, filenames in gfile.Walk(source_dir):
+        for filename in filenames:
+            name, ext = os.path.splitext(filename)
+            if ext.lower() == ".wav":
+                _, ext2 = os.path.splitext(name)
+                if ext2:
+                    continue
+                wav_file = os.path.join(root, filename)
+            elif ext.lower() == ".m4a":
+                # Convert AAC to WAV.
+                aac_file = os.path.join(root, filename)
+                wav_file = aac_file + ".wav"
+                if not gfile.Exists(wav_file):
+                    if not decode_aac_with_ffmpeg(aac_file, wav_file):
+                        raise RuntimeError("Audio decoding failed.")
+            else:
+                continue
+            speaker_name = root.split(os.path.sep)[-2]
+            if speaker_name not in speaker_id_dict:
+                num = len(speaker_id_dict)
+                speaker_id_dict[speaker_name] = num
+            data, rate = sf.read(wav_file)
+            wav_length = int(len(data) / rate * 1000)  # length in milliseconds, as the csv header states
+            files.append(
+                (os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
+            )
+
+    # Write to CSV file which contains four columns:
+    # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
+    csv_file_path = os.path.join(output_dir, output_file)
+    df = pandas.DataFrame(
+        data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
+    df.to_csv(csv_file_path, index=False, sep="\t")
+    logging.info("Successfully generated csv file {}".format(csv_file_path))
+
+
+def processor(directory, subset, force_process):
+    """ download and process """
+    urls = SUBSETS
+    if subset not in urls:
+        raise ValueError("subset %s is not in voxceleb" % subset)
+
+    subset_csv = os.path.join(directory, subset + '.csv')
+    if not force_process and os.path.exists(subset_csv):
+        return subset_csv
+
+    logging.info("Downloading and processing the voxceleb data in %s", directory)
+    logging.info("Preparing subset %s", subset)
+    download_and_extract(directory, subset, urls[subset])
+    convert_audio_and_make_label(
+        directory,
+        subset,
+        directory,
+        subset + ".csv"
+    )
+    logging.info("Finished downloading and processing")
+    return subset_csv
+
+
+if __name__ == "__main__":
+    logging.set_verbosity(logging.INFO)
+    if len(sys.argv) != 4:
+        print("Usage: python prepare_voxceleb.py save_directory user password")
+        sys.exit()
+
+    DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
+    for SUBSET in SUBSETS:
+        processor(DIR, SUBSET, False)
diff --git a/TTS/speaker_encoder/utils/visual.py b/TTS/speaker_encoder/utils/visual.py
new file mode 100644
index 00000000..68c48f12
--- /dev/null
+++ b/TTS/speaker_encoder/utils/visual.py
@@ -0,0 +1,46 @@
+import umap
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
+import matplotlib.pyplot as plt
+
+
+colormap = (
+    np.array(
+        [
+            [76, 255, 0],
+            [0, 127, 70],
+            [255, 0, 0],
+            [255, 217, 38],
+            [0, 135, 255],
+            [165, 0, 165],
+            [255, 167, 255],
+            [0, 255, 255],
+            [255, 96, 38],
+            [142, 76, 0],
+            [33, 0, 127],
+            [0, 0, 0],
+            [183, 183, 183],
+        ],
+        dtype=np.float64,
+    )
+    / 255
+)
+
+
+def plot_embeddings(embeddings, num_utter_per_speaker):
+    embeddings = embeddings[: 10 * num_utter_per_speaker]
+    model = umap.UMAP()
+    projection = model.fit_transform(embeddings)
+    num_speakers = embeddings.shape[0] // num_utter_per_speaker
+    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
+    colors = [colormap[i] for i in ground_truth]
+
+    fig, ax = plt.subplots(figsize=(16, 10))
+    _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
+    plt.gca().set_aspect("equal", "datalim")
+    plt.title("UMAP projection")
+    plt.tight_layout()
+    plt.savefig("umap")
+    return fig
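plot_embeddings caps the input at ten speakers' worth of utterances (the palette above has 13 colors) and writes the figure to disk through the Agg backend. A quick smoke test with random vectors, assuming umap-learn is installed and using made-up shapes:

    import numpy as np

    from TTS.speaker_encoder.utils.visual import plot_embeddings

    # 3 speakers x 10 utterances each, 256-dim embeddings (demo values only)
    embeddings = np.random.rand(30, 256).astype(np.float32)
    fig = plot_embeddings(embeddings, num_utter_per_speaker=10)
    # the call above also saves "umap.png" (savefig falls back to the default png format)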
subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") - if line.startswith("*")) - current.replace("* ", "") - except subprocess.CalledProcessError: - current = "inside_docker" - return current - - -def get_commit_hash(): - """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" - # try: - # subprocess.check_output(['git', 'diff-index', '--quiet', - # 'HEAD']) # Verify client is clean - # except: - # raise RuntimeError( - # " !! Commit before training to get the commit hash.") - try: - commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD']).decode().strip() - # Not copying .git folder into docker container - except subprocess.CalledProcessError: - commit = "0000000" - print(' > Git Hash: {}'.format(commit)) - return commit - - -def create_experiment_folder(root_path, model_name, debug): - """ Create a folder with the current date and time """ - date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") - if debug: - commit_hash = 'debug' - else: - commit_hash = get_commit_hash() - output_folder = os.path.join( - root_path, model_name + '-' + date_str + '-' + commit_hash) - os.makedirs(output_folder, exist_ok=True) - print(" > Experiment folder: {}".format(output_folder)) - return output_folder - - -def remove_experiment_folder(experiment_path): - """Check folder if there is a checkpoint, otherwise remove the folder""" - - checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") - if not checkpoint_files: - if os.path.exists(experiment_path): - shutil.rmtree(experiment_path, ignore_errors=True) - print(" ! Run is removed from {}".format(experiment_path)) - else: - print(" ! Run is kept in {}".format(experiment_path)) - - -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def split_dataset(items): - is_multi_speaker = False - speakers = [item[-1] for item in items] - is_multi_speaker = len(set(speakers)) > 1 - eval_split_size = 500 if len(items) * 0.01 > 500 else int( - len(items) * 0.01) - assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." - np.random.seed(0) - np.random.shuffle(items) - if is_multi_speaker: - items_eval = [] - # most stupid code ever -- Fix it ! - while len(items_eval) < eval_split_size: - speakers = [item[-1] for item in items] - speaker_counter = Counter(speakers) - item_idx = np.random.randint(0, len(items)) - if speaker_counter[items[item_idx][-1]] > 1: - items_eval.append(items[item_idx]) - del items[item_idx] - return items_eval, items - return items[:eval_split_size], items[eval_split_size:] - - -# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 -def sequence_mask(sequence_length, max_len=None): - if max_len is None: - max_len = sequence_length.data.max() - batch_size = sequence_length.size(0) - seq_range = torch.arange(0, max_len).long() - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - if sequence_length.is_cuda: - seq_range_expand = seq_range_expand.to(sequence_length.device) - seq_length_expand = ( - sequence_length.unsqueeze(1).expand_as(seq_range_expand)) - # B x T_max - return seq_range_expand < seq_length_expand - - -def set_init_dict(model_dict, checkpoint_state, c): - # Partial initialization: if there is a mismatch with new and old layer, it is skipped. 
- for k, v in checkpoint_state.items(): - if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) - # 1. filter out unnecessary keys - pretrained_dict = { - k: v - for k, v in checkpoint_state.items() if k in model_dict - } - # 2. filter out different size layers - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if v.numel() == model_dict[k].numel() - } - # 3. skip reinit layers - if c.reinit_layers is not None: - for reinit_layer_name in c.reinit_layers: - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if reinit_layer_name not in k - } - # 4. overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), - len(model_dict))) - return model_dict - - -def setup_model(num_chars, num_speakers, c): - print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('TTS.models.' + c.model.lower()) - MyModel = getattr(MyModel, c.model) - if c.model.lower() in "tacotron": - model = MyModel(num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - memory_size=c.memory_size, - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) - elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) - return model - -class KeepAverage(): - def __init__(self): - self.avg_values = {} - self.iters = {} - - def __getitem__(self, key): - return self.avg_values[key] - - def items(self): - return self.avg_values.items() - - def add_value(self, name, init_val=0, init_iter=0): - self.avg_values[name] = init_val - self.iters[name] = init_iter - - def update_value(self, name, value, weighted_avg=False): - if name not in self.avg_values: - # add value if not exist before - self.add_value(name, init_val=value) - else: - # else update existing value - if weighted_avg: - self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value - self.iters[name] += 1 - else: - self.avg_values[name] = self.avg_values[name] * \ - self.iters[name] + value - self.iters[name] += 1 - self.avg_values[name] /= self.iters[name] - - def 
add_values(self, name_dict): - for key, value in name_dict.items(): - self.add_value(key, init_val=value) - - def update_values(self, value_dict): - for key, value in value_dict.items(): - self.update_value(key, value) - - -def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): - if alternative in c.keys() and c[alternative] is not None: - return - if restricted: - assert name in c.keys(), f' [!] {name} not defined in config.json' - if name in c.keys(): - if max_val: - assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' - if min_val: - assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' - if enum_list: - assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' - if val_type: - assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' - - -def check_config(c): - _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) - _check_argument('run_name', c, restricted=True, val_type=str) - _check_argument('run_description', c, val_type=str) - - # AUDIO - _check_argument('audio', c, restricted=True, val_type=dict) - - # audio processing parameters - _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - _check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) - - # vocabulary parameters - _check_argument('characters', c, restricted=False, val_type=dict) - _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - - # normalization parameters - _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) - _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) - _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) - 
_check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) - _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) - _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - _check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100) - _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) - _check_argument('trim_db', c['audio'], restricted=True, val_type=int) - - # training parameters - _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) - _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) - _check_argument('r', c, restricted=True, val_type=int, min_val=1) - _check_argument('gradual_training', c, restricted=False, val_type=list) - _check_argument('loss_masking', c, restricted=True, val_type=bool) - # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) - - # validation parameters - _check_argument('run_eval', c, restricted=True, val_type=bool) - _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) - _check_argument('test_sentences_file', c, restricted=False, val_type=str) - - # optimizer - _check_argument('noam_schedule', c, restricted=False, val_type=bool) - _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) - _check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - _check_argument('lr', c, restricted=True, val_type=float, min_val=0) - _check_argument('wd', c, restricted=True, val_type=float, min_val=0) - _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - _check_argument('seq_len_norm', c, restricted=True, val_type=bool) - - # tacotron prenet - _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) - _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) - _check_argument('prenet_dropout', c, restricted=True, val_type=bool) - - # attention - _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) - _check_argument('attention_heads', c, restricted=True, val_type=int) - _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) - _check_argument('windowing', c, restricted=True, val_type=bool) - _check_argument('use_forward_attn', c, restricted=True, val_type=bool) - _check_argument('forward_attn_mask', c, restricted=True, val_type=bool) - _check_argument('transition_agent', c, restricted=True, val_type=bool) - _check_argument('transition_agent', c, restricted=True, val_type=bool) - _check_argument('location_attn', c, restricted=True, val_type=bool) - _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) - _check_argument('double_decoder_consistency', c, restricted=True, val_type=bool) - _check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) - - # stopnet - _check_argument('stopnet', c, restricted=True, val_type=bool) - _check_argument('separate_stopnet', c, restricted=True, val_type=bool) - - # tensorboard - _check_argument('print_step', c, restricted=True, val_type=int, min_val=1) - _check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) - _check_argument('save_step', c, restricted=True, val_type=int, min_val=1) - _check_argument('checkpoint', c, restricted=True, val_type=bool) - 
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) - - # dataloading - # pylint: disable=import-outside-toplevel - from TTS.utils.text import cleaners - _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) - _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) - _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) - _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) - _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) - _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) - _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) - - # paths - _check_argument('output_path', c, restricted=True, val_type=str) - - # multi-speaker - _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - - # GST - _check_argument('use_gst', c, restricted=True, val_type=bool) - _check_argument('gst', c, restricted=True, val_type=dict) - _check_argument('gst_style_input', c['gst'], restricted=True, val_type=str) - _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1) - _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1) - _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1) - - # datasets - checking only the first entry - _check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - _check_argument('name', dataset_entry, restricted=True, val_type=str) - _check_argument('path', dataset_entry, restricted=True, val_type=str) - _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) - _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
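For reference, the deleted _check_argument (which lives on in TTS.utils.generic_utils as check_argument, the function the new speaker encoder module imports) enforces a simple contract: a restricted key must exist, and a present value must satisfy the type, range, and enum constraints. An illustrative run against a toy config, assuming the surviving function keeps the signature shown above:

    from TTS.utils.generic_utils import check_argument

    c = {'run_name': 'speaker-encoder', 'lr': 1e-4, 'audio': {'num_mels': 80}}
    check_argument('run_name', c, restricted=True, val_type=str)         # passes
    check_argument('lr', c, restricted=True, val_type=float, min_val=0)  # passes
    check_argument('num_mels', c['audio'], restricted=True, val_type=int,
                   min_val=10, max_val=2056)                             # passes
    check_argument('model', c, restricted=True, val_type=dict)           # AssertionError: model not defined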