mirror of https://github.com/coqui-ai/TTS.git
format utils
This commit is contained in:
parent
4873601694
commit
a2606fbc22
|
@ -0,0 +1,118 @@
|
|||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||
from TTS.utils.generic_utils import check_argument
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(c):
|
||||
model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
|
||||
c.model['lstm_dim'], c.model['num_lstm_layers'])
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path,
|
||||
current_step, epoch):
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
'loss': model_loss,
|
||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
torch.save(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||
current_step):
|
||||
if model_loss < best_loss:
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'loss': model_loss,
|
||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
best_loss = model_loss
|
||||
bestmodel_path = 'best_model.pth.tar'
|
||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
|
||||
model_loss, bestmodel_path))
|
||||
torch.save(state, bestmodel_path)
|
||||
return best_loss
|
||||
|
||||
|
||||
def check_config_speaker_encoder(c):
|
||||
"""Check the config.json file of the speaker encoder"""
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
check_argument('run_description', c, val_type=str)
|
||||
|
||||
# audio processing parameters
|
||||
check_argument('audio', c, restricted=True, val_type=dict)
|
||||
check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||
check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||
check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||
check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||
check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||
check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||
check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||
check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||
|
||||
# training parameters
|
||||
check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
|
||||
check_argument('grad_clip', c, restricted=True, val_type=float)
|
||||
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
||||
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('lr_decay', c, restricted=True, val_type=bool)
|
||||
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
||||
check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
|
||||
check_argument('num_loader_workers', c, restricted=True, val_type=int)
|
||||
check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
|
||||
# checkpoint and output parameters
|
||||
check_argument('steps_plot_stats', c, restricted=True, val_type=int)
|
||||
check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
check_argument('save_step', c, restricted=True, val_type=int)
|
||||
check_argument('print_step', c, restricted=True, val_type=int)
|
||||
check_argument('output_path', c, restricted=True, val_type=str)
|
||||
|
||||
# model parameters
|
||||
check_argument('model', c, restricted=True, val_type=dict)
|
||||
check_argument('input_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('proj_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
|
||||
check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
|
||||
check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
|
||||
|
||||
# in-memory storage parameters
|
||||
check_argument('storage', c, restricted=True, val_type=dict)
|
||||
check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||
|
||||
# datasets - checking only the first entry
|
||||
check_argument('datasets', c, restricted=True, val_type=list)
|
||||
for dataset_entry in c['datasets']:
|
||||
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
||||
|
|
@ -0,0 +1,233 @@
|
|||
# coding=utf-8
|
||||
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Only support eager mode and TF>=2.0.0
|
||||
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
|
||||
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
|
||||
''' voxceleb 1 & 2 '''
|
||||
|
||||
import os
|
||||
import sys
|
||||
import zipfile
|
||||
import subprocess
|
||||
import hashlib
|
||||
import pandas
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
import soundfile as sf
|
||||
|
||||
gfile = tf.compat.v1.gfile
|
||||
|
||||
SUBSETS = {
|
||||
"vox1_dev_wav":
|
||||
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
|
||||
"vox1_test_wav":
|
||||
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
|
||||
"vox2_dev_aac":
|
||||
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
|
||||
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
|
||||
"vox2_test_aac":
|
||||
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
|
||||
}
|
||||
|
||||
MD5SUM = {
|
||||
"vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
|
||||
"vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
|
||||
"vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
|
||||
"vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
|
||||
}
|
||||
|
||||
USER = {
|
||||
"user": "",
|
||||
"password": ""
|
||||
}
|
||||
|
||||
speaker_id_dict = {}
|
||||
|
||||
def download_and_extract(directory, subset, urls):
|
||||
"""Download and extract the given split of dataset.
|
||||
|
||||
Args:
|
||||
directory: the directory where to put the downloaded data.
|
||||
subset: subset name of the corpus.
|
||||
urls: the list of urls to download the data file.
|
||||
"""
|
||||
if not gfile.Exists(directory):
|
||||
gfile.MakeDirs(directory)
|
||||
|
||||
try:
|
||||
for url in urls:
|
||||
zip_filepath = os.path.join(directory, url.split("/")[-1])
|
||||
if os.path.exists(zip_filepath):
|
||||
continue
|
||||
logging.info("Downloading %s to %s" % (url, zip_filepath))
|
||||
subprocess.call('wget %s --user %s --password %s -O %s' %
|
||||
(url, USER["user"], USER["password"], zip_filepath), shell=True)
|
||||
|
||||
statinfo = os.stat(zip_filepath)
|
||||
logging.info(
|
||||
"Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
|
||||
)
|
||||
|
||||
# concatenate all parts into zip files
|
||||
if ".zip" not in zip_filepath:
|
||||
zip_filepath = "_".join(zip_filepath.split("_")[:-1])
|
||||
subprocess.call('cat %s* > %s.zip' %
|
||||
(zip_filepath, zip_filepath), shell=True)
|
||||
zip_filepath += ".zip"
|
||||
extract_path = zip_filepath.strip(".zip")
|
||||
|
||||
# check zip file md5sum
|
||||
md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest()
|
||||
if md5 != MD5SUM[subset]:
|
||||
raise ValueError("md5sum of %s mismatch" % zip_filepath)
|
||||
|
||||
with zipfile.ZipFile(zip_filepath, "r") as zfile:
|
||||
zfile.extractall(directory)
|
||||
extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
|
||||
subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
|
||||
finally:
|
||||
# gfile.Remove(zip_filepath)
|
||||
pass
|
||||
|
||||
|
||||
def exec_cmd(cmd):
|
||||
"""Run a command in a subprocess.
|
||||
Args:
|
||||
cmd: command line to be executed.
|
||||
Return:
|
||||
int, the return code.
|
||||
"""
|
||||
try:
|
||||
retcode = subprocess.call(cmd, shell=True)
|
||||
if retcode < 0:
|
||||
logging.info(f"Child was terminated by signal {retcode}")
|
||||
except OSError as e:
|
||||
logging.info(f"Execution failed: {e}")
|
||||
retcode = -999
|
||||
return retcode
|
||||
|
||||
|
||||
def decode_aac_with_ffmpeg(aac_file, wav_file):
|
||||
"""Decode a given AAC file into WAV using ffmpeg.
|
||||
Args:
|
||||
aac_file: file path to input AAC file.
|
||||
wav_file: file path to output WAV file.
|
||||
Return:
|
||||
bool, True if success.
|
||||
"""
|
||||
cmd = f"ffmpeg -i {aac_file} {wav_file}"
|
||||
logging.info(f"Decoding aac file using command line: {cmd}")
|
||||
ret = exec_cmd(cmd)
|
||||
if ret != 0:
|
||||
logging.error(f"Failed to decode aac file with retcode {ret}")
|
||||
logging.error("Please check your ffmpeg installation.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def convert_audio_and_make_label(input_dir, subset,
|
||||
output_dir, output_file):
|
||||
"""Optionally convert AAC to WAV and make speaker labels.
|
||||
Args:
|
||||
input_dir: the directory which holds the input dataset.
|
||||
subset: the name of the specified subset. e.g. vox1_dev_wav
|
||||
output_dir: the directory to place the newly generated csv files.
|
||||
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
|
||||
"""
|
||||
|
||||
logging.info("Preprocessing audio and label for subset %s" % subset)
|
||||
source_dir = os.path.join(input_dir, subset)
|
||||
|
||||
files = []
|
||||
# Convert all AAC file into WAV format. At the same time, generate the csv
|
||||
for root, _, filenames in gfile.Walk(source_dir):
|
||||
for filename in filenames:
|
||||
name, ext = os.path.splitext(filename)
|
||||
if ext.lower() == ".wav":
|
||||
_, ext2 = (os.path.splitext(name))
|
||||
if ext2:
|
||||
continue
|
||||
wav_file = os.path.join(root, filename)
|
||||
elif ext.lower() == ".m4a":
|
||||
# Convert AAC to WAV.
|
||||
aac_file = os.path.join(root, filename)
|
||||
wav_file = aac_file + ".wav"
|
||||
if not gfile.Exists(wav_file):
|
||||
if not decode_aac_with_ffmpeg(aac_file, wav_file):
|
||||
raise RuntimeError("Audio decoding failed.")
|
||||
else:
|
||||
continue
|
||||
speaker_name = root.split(os.path.sep)[-2]
|
||||
if speaker_name not in speaker_id_dict:
|
||||
num = len(speaker_id_dict)
|
||||
speaker_id_dict[speaker_name] = num
|
||||
# wav_filesize = os.path.getsize(wav_file)
|
||||
wav_length = len(sf.read(wav_file)[0])
|
||||
files.append(
|
||||
(os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
|
||||
)
|
||||
|
||||
# Write to CSV file which contains four columns:
|
||||
# "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
|
||||
csv_file_path = os.path.join(output_dir, output_file)
|
||||
df = pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
|
||||
df.to_csv(csv_file_path, index=False, sep="\t")
|
||||
logging.info("Successfully generated csv file {}".format(csv_file_path))
|
||||
|
||||
|
||||
def processor(directory, subset, force_process):
|
||||
""" download and process """
|
||||
urls = SUBSETS
|
||||
if subset not in urls:
|
||||
raise ValueError(subset, "is not in voxceleb")
|
||||
|
||||
subset_csv = os.path.join(directory, subset + '.csv')
|
||||
if not force_process and os.path.exists(subset_csv):
|
||||
return subset_csv
|
||||
|
||||
logging.info("Downloading and process the voxceleb in %s", directory)
|
||||
logging.info("Preparing subset %s", subset)
|
||||
download_and_extract(directory, subset, urls[subset])
|
||||
convert_audio_and_make_label(
|
||||
directory,
|
||||
subset,
|
||||
directory,
|
||||
subset + ".csv"
|
||||
)
|
||||
logging.info("Finished downloading and processing")
|
||||
return subset_csv
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.set_verbosity(logging.INFO)
|
||||
if len(sys.argv) != 4:
|
||||
print("Usage: python prepare_data.py save_directory user password")
|
||||
sys.exit()
|
||||
|
||||
DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
for SUBSET in SUBSETS:
|
||||
processor(DIR, SUBSET, False)
|
|
@ -0,0 +1,46 @@
|
|||
import umap
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
||||
colormap = (
|
||||
np.array(
|
||||
[
|
||||
[76, 255, 0],
|
||||
[0, 127, 70],
|
||||
[255, 0, 0],
|
||||
[255, 217, 38],
|
||||
[0, 135, 255],
|
||||
[165, 0, 165],
|
||||
[255, 167, 255],
|
||||
[0, 255, 255],
|
||||
[255, 96, 38],
|
||||
[142, 76, 0],
|
||||
[33, 0, 127],
|
||||
[0, 0, 0],
|
||||
[183, 183, 183],
|
||||
],
|
||||
dtype=np.float,
|
||||
)
|
||||
/ 255
|
||||
)
|
||||
|
||||
|
||||
def plot_embeddings(embeddings, num_utter_per_speaker):
|
||||
embeddings = embeddings[: 10 * num_utter_per_speaker]
|
||||
model = umap.UMAP()
|
||||
projection = model.fit_transform(embeddings)
|
||||
num_speakers = embeddings.shape[0] // num_utter_per_speaker
|
||||
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
|
||||
colors = [colormap[i] for i in ground_truth]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
||||
plt.gca().set_aspect("equal", "datalim")
|
||||
plt.title("UMAP projection")
|
||||
plt.tight_layout()
|
||||
plt.savefig("umap")
|
||||
return fig
|
|
@ -1,374 +0,0 @@
|
|||
import os
|
||||
import glob
|
||||
import torch
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import importlib
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
try:
|
||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||
current = next(line for line in out.split("\n")
|
||||
if line.startswith("*"))
|
||||
current.replace("* ", "")
|
||||
except subprocess.CalledProcessError:
|
||||
current = "inside_docker"
|
||||
return current
|
||||
|
||||
|
||||
def get_commit_hash():
|
||||
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
||||
# try:
|
||||
# subprocess.check_output(['git', 'diff-index', '--quiet',
|
||||
# 'HEAD']) # Verify client is clean
|
||||
# except:
|
||||
# raise RuntimeError(
|
||||
# " !! Commit before training to get the commit hash.")
|
||||
try:
|
||||
commit = subprocess.check_output(
|
||||
['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
|
||||
# Not copying .git folder into docker container
|
||||
except subprocess.CalledProcessError:
|
||||
commit = "0000000"
|
||||
print(' > Git Hash: {}'.format(commit))
|
||||
return commit
|
||||
|
||||
|
||||
def create_experiment_folder(root_path, model_name, debug):
|
||||
""" Create a folder with the current date and time """
|
||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||
if debug:
|
||||
commit_hash = 'debug'
|
||||
else:
|
||||
commit_hash = get_commit_hash()
|
||||
output_folder = os.path.join(
|
||||
root_path, model_name + '-' + date_str + '-' + commit_hash)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
print(" > Experiment folder: {}".format(output_folder))
|
||||
return output_folder
|
||||
|
||||
|
||||
def remove_experiment_folder(experiment_path):
|
||||
"""Check folder if there is a checkpoint, otherwise remove the folder"""
|
||||
|
||||
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
|
||||
if not checkpoint_files:
|
||||
if os.path.exists(experiment_path):
|
||||
shutil.rmtree(experiment_path, ignore_errors=True)
|
||||
print(" ! Run is removed from {}".format(experiment_path))
|
||||
else:
|
||||
print(" ! Run is kept in {}".format(experiment_path))
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
r"""Count number of trainable parameters in a network"""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def split_dataset(items):
|
||||
is_multi_speaker = False
|
||||
speakers = [item[-1] for item in items]
|
||||
is_multi_speaker = len(set(speakers)) > 1
|
||||
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
|
||||
len(items) * 0.01)
|
||||
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
|
||||
np.random.seed(0)
|
||||
np.random.shuffle(items)
|
||||
if is_multi_speaker:
|
||||
items_eval = []
|
||||
# most stupid code ever -- Fix it !
|
||||
while len(items_eval) < eval_split_size:
|
||||
speakers = [item[-1] for item in items]
|
||||
speaker_counter = Counter(speakers)
|
||||
item_idx = np.random.randint(0, len(items))
|
||||
if speaker_counter[items[item_idx][-1]] > 1:
|
||||
items_eval.append(items[item_idx])
|
||||
del items[item_idx]
|
||||
return items_eval, items
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.to(sequence_length.device)
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint_state.items():
|
||||
if k not in model_dict:
|
||||
print(" | > Layer missing in the model definition: {}".format(k))
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {
|
||||
k: v
|
||||
for k, v in checkpoint_state.items() if k in model_dict
|
||||
}
|
||||
# 2. filter out different size layers
|
||||
pretrained_dict = {
|
||||
k: v
|
||||
for k, v in pretrained_dict.items()
|
||||
if v.numel() == model_dict[k].numel()
|
||||
}
|
||||
# 3. skip reinit layers
|
||||
if c.reinit_layers is not None:
|
||||
for reinit_layer_name in c.reinit_layers:
|
||||
pretrained_dict = {
|
||||
k: v
|
||||
for k, v in pretrained_dict.items()
|
||||
if reinit_layer_name not in k
|
||||
}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
|
||||
len(model_dict)))
|
||||
return model_dict
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('TTS.models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
gst=c.use_gst,
|
||||
gst_embedding_dim=c.gst['gst_embedding_dim'],
|
||||
gst_num_heads=c.gst['gst_num_heads'],
|
||||
gst_style_tokens=c.gst['gst_style_tokens'],
|
||||
memory_size=c.memory_size,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
gst=c.use_gst,
|
||||
gst_embedding_dim=c.gst['gst_embedding_dim'],
|
||||
gst_num_heads=c.gst['gst_num_heads'],
|
||||
gst_style_tokens=c.gst['gst_style_tokens'],
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
return model
|
||||
|
||||
class KeepAverage():
|
||||
def __init__(self):
|
||||
self.avg_values = {}
|
||||
self.iters = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.avg_values[key]
|
||||
|
||||
def items(self):
|
||||
return self.avg_values.items()
|
||||
|
||||
def add_value(self, name, init_val=0, init_iter=0):
|
||||
self.avg_values[name] = init_val
|
||||
self.iters[name] = init_iter
|
||||
|
||||
def update_value(self, name, value, weighted_avg=False):
|
||||
if name not in self.avg_values:
|
||||
# add value if not exist before
|
||||
self.add_value(name, init_val=value)
|
||||
else:
|
||||
# else update existing value
|
||||
if weighted_avg:
|
||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
||||
self.iters[name] += 1
|
||||
else:
|
||||
self.avg_values[name] = self.avg_values[name] * \
|
||||
self.iters[name] + value
|
||||
self.iters[name] += 1
|
||||
self.avg_values[name] /= self.iters[name]
|
||||
|
||||
def add_values(self, name_dict):
|
||||
for key, value in name_dict.items():
|
||||
self.add_value(key, init_val=value)
|
||||
|
||||
def update_values(self, value_dict):
|
||||
for key, value in value_dict.items():
|
||||
self.update_value(key, value)
|
||||
|
||||
|
||||
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
|
||||
if alternative in c.keys() and c[alternative] is not None:
|
||||
return
|
||||
if restricted:
|
||||
assert name in c.keys(), f' [!] {name} not defined in config.json'
|
||||
if name in c.keys():
|
||||
if max_val:
|
||||
assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
|
||||
if min_val:
|
||||
assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
|
||||
if enum_list:
|
||||
assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
|
||||
if val_type:
|
||||
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
|
||||
|
||||
|
||||
def check_config(c):
|
||||
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||
_check_argument('run_name', c, restricted=True, val_type=str)
|
||||
_check_argument('run_description', c, val_type=str)
|
||||
|
||||
# AUDIO
|
||||
_check_argument('audio', c, restricted=True, val_type=dict)
|
||||
|
||||
# audio processing parameters
|
||||
_check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||
_check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||
_check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||
_check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||
_check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||
_check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||
_check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||
_check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||
_check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||
_check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||
|
||||
# vocabulary parameters
|
||||
_check_argument('characters', c, restricted=False, val_type=dict)
|
||||
_check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
_check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
_check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
_check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
_check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
_check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
||||
|
||||
# normalization parameters
|
||||
_check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
|
||||
_check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
|
||||
_check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
|
||||
_check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
|
||||
_check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
|
||||
_check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
|
||||
_check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100)
|
||||
_check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
|
||||
_check_argument('trim_db', c['audio'], restricted=True, val_type=int)
|
||||
|
||||
# training parameters
|
||||
_check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('r', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('gradual_training', c, restricted=False, val_type=list)
|
||||
_check_argument('loss_masking', c, restricted=True, val_type=bool)
|
||||
# _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
|
||||
# validation parameters
|
||||
_check_argument('run_eval', c, restricted=True, val_type=bool)
|
||||
_check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('test_sentences_file', c, restricted=False, val_type=str)
|
||||
|
||||
# optimizer
|
||||
_check_argument('noam_schedule', c, restricted=False, val_type=bool)
|
||||
_check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
|
||||
_check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
||||
_check_argument('wd', c, restricted=True, val_type=float, min_val=0)
|
||||
_check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('seq_len_norm', c, restricted=True, val_type=bool)
|
||||
|
||||
# tacotron prenet
|
||||
_check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
|
||||
_check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
|
||||
_check_argument('prenet_dropout', c, restricted=True, val_type=bool)
|
||||
|
||||
# attention
|
||||
_check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
|
||||
_check_argument('attention_heads', c, restricted=True, val_type=int)
|
||||
_check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
|
||||
_check_argument('windowing', c, restricted=True, val_type=bool)
|
||||
_check_argument('use_forward_attn', c, restricted=True, val_type=bool)
|
||||
_check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
|
||||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
_check_argument('location_attn', c, restricted=True, val_type=bool)
|
||||
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
||||
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
||||
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
|
||||
# stopnet
|
||||
_check_argument('stopnet', c, restricted=True, val_type=bool)
|
||||
_check_argument('separate_stopnet', c, restricted=True, val_type=bool)
|
||||
|
||||
# tensorboard
|
||||
_check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||
|
||||
# dataloading
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from TTS.utils.text import cleaners
|
||||
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
||||
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
||||
_check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
|
||||
_check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
|
||||
|
||||
# paths
|
||||
_check_argument('output_path', c, restricted=True, val_type=str)
|
||||
|
||||
# multi-speaker
|
||||
_check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
||||
|
||||
# GST
|
||||
_check_argument('use_gst', c, restricted=True, val_type=bool)
|
||||
_check_argument('gst', c, restricted=True, val_type=dict)
|
||||
_check_argument('gst_style_input', c['gst'], restricted=True, val_type=str)
|
||||
_check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
|
||||
_check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
|
||||
|
||||
# datasets - checking only the first entry
|
||||
_check_argument('datasets', c, restricted=True, val_type=list)
|
||||
for dataset_entry in c['datasets']:
|
||||
_check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
_check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
_check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
|
||||
_check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
Loading…
Reference in New Issue