mirror of https://github.com/coqui-ai/TTS.git
format utils
This commit is contained in:
parent
4873601694
commit
a2606fbc22
|
@ -0,0 +1,118 @@
|
||||||
|
import datetime
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||||
|
from TTS.utils.generic_utils import check_argument
|
||||||
|
|
||||||
|
|
||||||
|
def to_camel(text):
|
||||||
|
text = text.capitalize()
|
||||||
|
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model(c):
|
||||||
|
model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
|
||||||
|
c.model['lstm_dim'], c.model['num_lstm_layers'])
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(model, optimizer, model_loss, out_path,
|
||||||
|
current_step, epoch):
|
||||||
|
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
|
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||||
|
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||||
|
|
||||||
|
new_state_dict = model.state_dict()
|
||||||
|
state = {
|
||||||
|
'model': new_state_dict,
|
||||||
|
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
||||||
|
'step': current_step,
|
||||||
|
'epoch': epoch,
|
||||||
|
'loss': model_loss,
|
||||||
|
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||||
|
}
|
||||||
|
torch.save(state, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||||
|
current_step):
|
||||||
|
if model_loss < best_loss:
|
||||||
|
new_state_dict = model.state_dict()
|
||||||
|
state = {
|
||||||
|
'model': new_state_dict,
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'step': current_step,
|
||||||
|
'loss': model_loss,
|
||||||
|
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||||
|
}
|
||||||
|
best_loss = model_loss
|
||||||
|
bestmodel_path = 'best_model.pth.tar'
|
||||||
|
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||||
|
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
|
||||||
|
model_loss, bestmodel_path))
|
||||||
|
torch.save(state, bestmodel_path)
|
||||||
|
return best_loss
|
||||||
|
|
||||||
|
|
||||||
|
def check_config_speaker_encoder(c):
|
||||||
|
"""Check the config.json file of the speaker encoder"""
|
||||||
|
check_argument('run_name', c, restricted=True, val_type=str)
|
||||||
|
check_argument('run_description', c, val_type=str)
|
||||||
|
|
||||||
|
# audio processing parameters
|
||||||
|
check_argument('audio', c, restricted=True, val_type=dict)
|
||||||
|
check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
||||||
|
check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
||||||
|
check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
||||||
|
check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
||||||
|
check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
||||||
|
check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
||||||
|
check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
||||||
|
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
||||||
|
check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
||||||
|
check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
||||||
|
|
||||||
|
# training parameters
|
||||||
|
check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
|
||||||
|
check_argument('grad_clip', c, restricted=True, val_type=float)
|
||||||
|
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
||||||
|
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
||||||
|
check_argument('lr_decay', c, restricted=True, val_type=bool)
|
||||||
|
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
||||||
|
check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||||
|
check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
|
||||||
|
check_argument('num_loader_workers', c, restricted=True, val_type=int)
|
||||||
|
check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||||
|
|
||||||
|
# checkpoint and output parameters
|
||||||
|
check_argument('steps_plot_stats', c, restricted=True, val_type=int)
|
||||||
|
check_argument('checkpoint', c, restricted=True, val_type=bool)
|
||||||
|
check_argument('save_step', c, restricted=True, val_type=int)
|
||||||
|
check_argument('print_step', c, restricted=True, val_type=int)
|
||||||
|
check_argument('output_path', c, restricted=True, val_type=str)
|
||||||
|
|
||||||
|
# model parameters
|
||||||
|
check_argument('model', c, restricted=True, val_type=dict)
|
||||||
|
check_argument('input_dim', c['model'], restricted=True, val_type=int)
|
||||||
|
check_argument('proj_dim', c['model'], restricted=True, val_type=int)
|
||||||
|
check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
|
||||||
|
check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
|
||||||
|
check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
|
||||||
|
|
||||||
|
# in-memory storage parameters
|
||||||
|
check_argument('storage', c, restricted=True, val_type=dict)
|
||||||
|
check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||||
|
check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
|
||||||
|
check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
|
||||||
|
|
||||||
|
# datasets - checking only the first entry
|
||||||
|
check_argument('datasets', c, restricted=True, val_type=list)
|
||||||
|
for dataset_entry in c['datasets']:
|
||||||
|
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||||
|
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||||
|
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||||
|
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
||||||
|
|
|
@ -0,0 +1,233 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
# Only support eager mode and TF>=2.0.0
|
||||||
|
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
|
||||||
|
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
|
||||||
|
''' voxceleb 1 & 2 '''
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import zipfile
|
||||||
|
import subprocess
|
||||||
|
import hashlib
|
||||||
|
import pandas
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow as tf
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
gfile = tf.compat.v1.gfile
|
||||||
|
|
||||||
|
SUBSETS = {
|
||||||
|
"vox1_dev_wav":
|
||||||
|
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
|
||||||
|
"vox1_test_wav":
|
||||||
|
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
|
||||||
|
"vox2_dev_aac":
|
||||||
|
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
|
||||||
|
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
|
||||||
|
"vox2_test_aac":
|
||||||
|
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
|
||||||
|
}
|
||||||
|
|
||||||
|
MD5SUM = {
|
||||||
|
"vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
|
||||||
|
"vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
|
||||||
|
"vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
|
||||||
|
"vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
|
||||||
|
}
|
||||||
|
|
||||||
|
USER = {
|
||||||
|
"user": "",
|
||||||
|
"password": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
speaker_id_dict = {}
|
||||||
|
|
||||||
|
def download_and_extract(directory, subset, urls):
|
||||||
|
"""Download and extract the given split of dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: the directory where to put the downloaded data.
|
||||||
|
subset: subset name of the corpus.
|
||||||
|
urls: the list of urls to download the data file.
|
||||||
|
"""
|
||||||
|
if not gfile.Exists(directory):
|
||||||
|
gfile.MakeDirs(directory)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for url in urls:
|
||||||
|
zip_filepath = os.path.join(directory, url.split("/")[-1])
|
||||||
|
if os.path.exists(zip_filepath):
|
||||||
|
continue
|
||||||
|
logging.info("Downloading %s to %s" % (url, zip_filepath))
|
||||||
|
subprocess.call('wget %s --user %s --password %s -O %s' %
|
||||||
|
(url, USER["user"], USER["password"], zip_filepath), shell=True)
|
||||||
|
|
||||||
|
statinfo = os.stat(zip_filepath)
|
||||||
|
logging.info(
|
||||||
|
"Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# concatenate all parts into zip files
|
||||||
|
if ".zip" not in zip_filepath:
|
||||||
|
zip_filepath = "_".join(zip_filepath.split("_")[:-1])
|
||||||
|
subprocess.call('cat %s* > %s.zip' %
|
||||||
|
(zip_filepath, zip_filepath), shell=True)
|
||||||
|
zip_filepath += ".zip"
|
||||||
|
extract_path = zip_filepath.strip(".zip")
|
||||||
|
|
||||||
|
# check zip file md5sum
|
||||||
|
md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest()
|
||||||
|
if md5 != MD5SUM[subset]:
|
||||||
|
raise ValueError("md5sum of %s mismatch" % zip_filepath)
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_filepath, "r") as zfile:
|
||||||
|
zfile.extractall(directory)
|
||||||
|
extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
|
||||||
|
subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
|
||||||
|
finally:
|
||||||
|
# gfile.Remove(zip_filepath)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def exec_cmd(cmd):
|
||||||
|
"""Run a command in a subprocess.
|
||||||
|
Args:
|
||||||
|
cmd: command line to be executed.
|
||||||
|
Return:
|
||||||
|
int, the return code.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
retcode = subprocess.call(cmd, shell=True)
|
||||||
|
if retcode < 0:
|
||||||
|
logging.info(f"Child was terminated by signal {retcode}")
|
||||||
|
except OSError as e:
|
||||||
|
logging.info(f"Execution failed: {e}")
|
||||||
|
retcode = -999
|
||||||
|
return retcode
|
||||||
|
|
||||||
|
|
||||||
|
def decode_aac_with_ffmpeg(aac_file, wav_file):
|
||||||
|
"""Decode a given AAC file into WAV using ffmpeg.
|
||||||
|
Args:
|
||||||
|
aac_file: file path to input AAC file.
|
||||||
|
wav_file: file path to output WAV file.
|
||||||
|
Return:
|
||||||
|
bool, True if success.
|
||||||
|
"""
|
||||||
|
cmd = f"ffmpeg -i {aac_file} {wav_file}"
|
||||||
|
logging.info(f"Decoding aac file using command line: {cmd}")
|
||||||
|
ret = exec_cmd(cmd)
|
||||||
|
if ret != 0:
|
||||||
|
logging.error(f"Failed to decode aac file with retcode {ret}")
|
||||||
|
logging.error("Please check your ffmpeg installation.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def convert_audio_and_make_label(input_dir, subset,
|
||||||
|
output_dir, output_file):
|
||||||
|
"""Optionally convert AAC to WAV and make speaker labels.
|
||||||
|
Args:
|
||||||
|
input_dir: the directory which holds the input dataset.
|
||||||
|
subset: the name of the specified subset. e.g. vox1_dev_wav
|
||||||
|
output_dir: the directory to place the newly generated csv files.
|
||||||
|
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging.info("Preprocessing audio and label for subset %s" % subset)
|
||||||
|
source_dir = os.path.join(input_dir, subset)
|
||||||
|
|
||||||
|
files = []
|
||||||
|
# Convert all AAC file into WAV format. At the same time, generate the csv
|
||||||
|
for root, _, filenames in gfile.Walk(source_dir):
|
||||||
|
for filename in filenames:
|
||||||
|
name, ext = os.path.splitext(filename)
|
||||||
|
if ext.lower() == ".wav":
|
||||||
|
_, ext2 = (os.path.splitext(name))
|
||||||
|
if ext2:
|
||||||
|
continue
|
||||||
|
wav_file = os.path.join(root, filename)
|
||||||
|
elif ext.lower() == ".m4a":
|
||||||
|
# Convert AAC to WAV.
|
||||||
|
aac_file = os.path.join(root, filename)
|
||||||
|
wav_file = aac_file + ".wav"
|
||||||
|
if not gfile.Exists(wav_file):
|
||||||
|
if not decode_aac_with_ffmpeg(aac_file, wav_file):
|
||||||
|
raise RuntimeError("Audio decoding failed.")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
speaker_name = root.split(os.path.sep)[-2]
|
||||||
|
if speaker_name not in speaker_id_dict:
|
||||||
|
num = len(speaker_id_dict)
|
||||||
|
speaker_id_dict[speaker_name] = num
|
||||||
|
# wav_filesize = os.path.getsize(wav_file)
|
||||||
|
wav_length = len(sf.read(wav_file)[0])
|
||||||
|
files.append(
|
||||||
|
(os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write to CSV file which contains four columns:
|
||||||
|
# "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
|
||||||
|
csv_file_path = os.path.join(output_dir, output_file)
|
||||||
|
df = pandas.DataFrame(
|
||||||
|
data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
|
||||||
|
df.to_csv(csv_file_path, index=False, sep="\t")
|
||||||
|
logging.info("Successfully generated csv file {}".format(csv_file_path))
|
||||||
|
|
||||||
|
|
||||||
|
def processor(directory, subset, force_process):
|
||||||
|
""" download and process """
|
||||||
|
urls = SUBSETS
|
||||||
|
if subset not in urls:
|
||||||
|
raise ValueError(subset, "is not in voxceleb")
|
||||||
|
|
||||||
|
subset_csv = os.path.join(directory, subset + '.csv')
|
||||||
|
if not force_process and os.path.exists(subset_csv):
|
||||||
|
return subset_csv
|
||||||
|
|
||||||
|
logging.info("Downloading and process the voxceleb in %s", directory)
|
||||||
|
logging.info("Preparing subset %s", subset)
|
||||||
|
download_and_extract(directory, subset, urls[subset])
|
||||||
|
convert_audio_and_make_label(
|
||||||
|
directory,
|
||||||
|
subset,
|
||||||
|
directory,
|
||||||
|
subset + ".csv"
|
||||||
|
)
|
||||||
|
logging.info("Finished downloading and processing")
|
||||||
|
return subset_csv
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
if len(sys.argv) != 4:
|
||||||
|
print("Usage: python prepare_data.py save_directory user password")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||||
|
for SUBSET in SUBSETS:
|
||||||
|
processor(DIR, SUBSET, False)
|
|
@ -0,0 +1,46 @@
|
||||||
|
import umap
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
|
||||||
|
colormap = (
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[76, 255, 0],
|
||||||
|
[0, 127, 70],
|
||||||
|
[255, 0, 0],
|
||||||
|
[255, 217, 38],
|
||||||
|
[0, 135, 255],
|
||||||
|
[165, 0, 165],
|
||||||
|
[255, 167, 255],
|
||||||
|
[0, 255, 255],
|
||||||
|
[255, 96, 38],
|
||||||
|
[142, 76, 0],
|
||||||
|
[33, 0, 127],
|
||||||
|
[0, 0, 0],
|
||||||
|
[183, 183, 183],
|
||||||
|
],
|
||||||
|
dtype=np.float,
|
||||||
|
)
|
||||||
|
/ 255
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_embeddings(embeddings, num_utter_per_speaker):
|
||||||
|
embeddings = embeddings[: 10 * num_utter_per_speaker]
|
||||||
|
model = umap.UMAP()
|
||||||
|
projection = model.fit_transform(embeddings)
|
||||||
|
num_speakers = embeddings.shape[0] // num_utter_per_speaker
|
||||||
|
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
|
||||||
|
colors = [colormap[i] for i in ground_truth]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(16, 10))
|
||||||
|
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
||||||
|
plt.gca().set_aspect("equal", "datalim")
|
||||||
|
plt.title("UMAP projection")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("umap")
|
||||||
|
return fig
|
|
@ -1,374 +0,0 @@
|
||||||
import os
|
|
||||||
import glob
|
|
||||||
import torch
|
|
||||||
import shutil
|
|
||||||
import datetime
|
|
||||||
import subprocess
|
|
||||||
import importlib
|
|
||||||
import numpy as np
|
|
||||||
from collections import Counter
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_branch():
|
|
||||||
try:
|
|
||||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
|
||||||
current = next(line for line in out.split("\n")
|
|
||||||
if line.startswith("*"))
|
|
||||||
current.replace("* ", "")
|
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
current = "inside_docker"
|
|
||||||
return current
|
|
||||||
|
|
||||||
|
|
||||||
def get_commit_hash():
|
|
||||||
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
|
||||||
# try:
|
|
||||||
# subprocess.check_output(['git', 'diff-index', '--quiet',
|
|
||||||
# 'HEAD']) # Verify client is clean
|
|
||||||
# except:
|
|
||||||
# raise RuntimeError(
|
|
||||||
# " !! Commit before training to get the commit hash.")
|
|
||||||
try:
|
|
||||||
commit = subprocess.check_output(
|
|
||||||
['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
|
|
||||||
# Not copying .git folder into docker container
|
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
commit = "0000000"
|
|
||||||
print(' > Git Hash: {}'.format(commit))
|
|
||||||
return commit
|
|
||||||
|
|
||||||
|
|
||||||
def create_experiment_folder(root_path, model_name, debug):
|
|
||||||
""" Create a folder with the current date and time """
|
|
||||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
|
||||||
if debug:
|
|
||||||
commit_hash = 'debug'
|
|
||||||
else:
|
|
||||||
commit_hash = get_commit_hash()
|
|
||||||
output_folder = os.path.join(
|
|
||||||
root_path, model_name + '-' + date_str + '-' + commit_hash)
|
|
||||||
os.makedirs(output_folder, exist_ok=True)
|
|
||||||
print(" > Experiment folder: {}".format(output_folder))
|
|
||||||
return output_folder
|
|
||||||
|
|
||||||
|
|
||||||
def remove_experiment_folder(experiment_path):
|
|
||||||
"""Check folder if there is a checkpoint, otherwise remove the folder"""
|
|
||||||
|
|
||||||
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
|
|
||||||
if not checkpoint_files:
|
|
||||||
if os.path.exists(experiment_path):
|
|
||||||
shutil.rmtree(experiment_path, ignore_errors=True)
|
|
||||||
print(" ! Run is removed from {}".format(experiment_path))
|
|
||||||
else:
|
|
||||||
print(" ! Run is kept in {}".format(experiment_path))
|
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
|
||||||
r"""Count number of trainable parameters in a network"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(items):
|
|
||||||
is_multi_speaker = False
|
|
||||||
speakers = [item[-1] for item in items]
|
|
||||||
is_multi_speaker = len(set(speakers)) > 1
|
|
||||||
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
|
|
||||||
len(items) * 0.01)
|
|
||||||
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
|
|
||||||
np.random.seed(0)
|
|
||||||
np.random.shuffle(items)
|
|
||||||
if is_multi_speaker:
|
|
||||||
items_eval = []
|
|
||||||
# most stupid code ever -- Fix it !
|
|
||||||
while len(items_eval) < eval_split_size:
|
|
||||||
speakers = [item[-1] for item in items]
|
|
||||||
speaker_counter = Counter(speakers)
|
|
||||||
item_idx = np.random.randint(0, len(items))
|
|
||||||
if speaker_counter[items[item_idx][-1]] > 1:
|
|
||||||
items_eval.append(items[item_idx])
|
|
||||||
del items[item_idx]
|
|
||||||
return items_eval, items
|
|
||||||
return items[:eval_split_size], items[eval_split_size:]
|
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
|
||||||
def sequence_mask(sequence_length, max_len=None):
|
|
||||||
if max_len is None:
|
|
||||||
max_len = sequence_length.data.max()
|
|
||||||
batch_size = sequence_length.size(0)
|
|
||||||
seq_range = torch.arange(0, max_len).long()
|
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
|
||||||
if sequence_length.is_cuda:
|
|
||||||
seq_range_expand = seq_range_expand.to(sequence_length.device)
|
|
||||||
seq_length_expand = (
|
|
||||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
|
||||||
# B x T_max
|
|
||||||
return seq_range_expand < seq_length_expand
|
|
||||||
|
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint_state, c):
|
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
|
||||||
for k, v in checkpoint_state.items():
|
|
||||||
if k not in model_dict:
|
|
||||||
print(" | > Layer missing in the model definition: {}".format(k))
|
|
||||||
# 1. filter out unnecessary keys
|
|
||||||
pretrained_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in checkpoint_state.items() if k in model_dict
|
|
||||||
}
|
|
||||||
# 2. filter out different size layers
|
|
||||||
pretrained_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in pretrained_dict.items()
|
|
||||||
if v.numel() == model_dict[k].numel()
|
|
||||||
}
|
|
||||||
# 3. skip reinit layers
|
|
||||||
if c.reinit_layers is not None:
|
|
||||||
for reinit_layer_name in c.reinit_layers:
|
|
||||||
pretrained_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in pretrained_dict.items()
|
|
||||||
if reinit_layer_name not in k
|
|
||||||
}
|
|
||||||
# 4. overwrite entries in the existing state dict
|
|
||||||
model_dict.update(pretrained_dict)
|
|
||||||
print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
|
|
||||||
len(model_dict)))
|
|
||||||
return model_dict
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model(num_chars, num_speakers, c):
|
|
||||||
print(" > Using model: {}".format(c.model))
|
|
||||||
MyModel = importlib.import_module('TTS.models.' + c.model.lower())
|
|
||||||
MyModel = getattr(MyModel, c.model)
|
|
||||||
if c.model.lower() in "tacotron":
|
|
||||||
model = MyModel(num_chars=num_chars,
|
|
||||||
num_speakers=num_speakers,
|
|
||||||
r=c.r,
|
|
||||||
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
|
|
||||||
decoder_output_dim=c.audio['num_mels'],
|
|
||||||
gst=c.use_gst,
|
|
||||||
gst_embedding_dim=c.gst['gst_embedding_dim'],
|
|
||||||
gst_num_heads=c.gst['gst_num_heads'],
|
|
||||||
gst_style_tokens=c.gst['gst_style_tokens'],
|
|
||||||
memory_size=c.memory_size,
|
|
||||||
attn_type=c.attention_type,
|
|
||||||
attn_win=c.windowing,
|
|
||||||
attn_norm=c.attention_norm,
|
|
||||||
prenet_type=c.prenet_type,
|
|
||||||
prenet_dropout=c.prenet_dropout,
|
|
||||||
forward_attn=c.use_forward_attn,
|
|
||||||
trans_agent=c.transition_agent,
|
|
||||||
forward_attn_mask=c.forward_attn_mask,
|
|
||||||
location_attn=c.location_attn,
|
|
||||||
attn_K=c.attention_heads,
|
|
||||||
separate_stopnet=c.separate_stopnet,
|
|
||||||
bidirectional_decoder=c.bidirectional_decoder,
|
|
||||||
double_decoder_consistency=c.double_decoder_consistency,
|
|
||||||
ddc_r=c.ddc_r)
|
|
||||||
elif c.model.lower() == "tacotron2":
|
|
||||||
model = MyModel(num_chars=num_chars,
|
|
||||||
num_speakers=num_speakers,
|
|
||||||
r=c.r,
|
|
||||||
postnet_output_dim=c.audio['num_mels'],
|
|
||||||
decoder_output_dim=c.audio['num_mels'],
|
|
||||||
gst=c.use_gst,
|
|
||||||
gst_embedding_dim=c.gst['gst_embedding_dim'],
|
|
||||||
gst_num_heads=c.gst['gst_num_heads'],
|
|
||||||
gst_style_tokens=c.gst['gst_style_tokens'],
|
|
||||||
attn_type=c.attention_type,
|
|
||||||
attn_win=c.windowing,
|
|
||||||
attn_norm=c.attention_norm,
|
|
||||||
prenet_type=c.prenet_type,
|
|
||||||
prenet_dropout=c.prenet_dropout,
|
|
||||||
forward_attn=c.use_forward_attn,
|
|
||||||
trans_agent=c.transition_agent,
|
|
||||||
forward_attn_mask=c.forward_attn_mask,
|
|
||||||
location_attn=c.location_attn,
|
|
||||||
attn_K=c.attention_heads,
|
|
||||||
separate_stopnet=c.separate_stopnet,
|
|
||||||
bidirectional_decoder=c.bidirectional_decoder,
|
|
||||||
double_decoder_consistency=c.double_decoder_consistency,
|
|
||||||
ddc_r=c.ddc_r)
|
|
||||||
return model
|
|
||||||
|
|
||||||
class KeepAverage():
|
|
||||||
def __init__(self):
|
|
||||||
self.avg_values = {}
|
|
||||||
self.iters = {}
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.avg_values[key]
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return self.avg_values.items()
|
|
||||||
|
|
||||||
def add_value(self, name, init_val=0, init_iter=0):
|
|
||||||
self.avg_values[name] = init_val
|
|
||||||
self.iters[name] = init_iter
|
|
||||||
|
|
||||||
def update_value(self, name, value, weighted_avg=False):
|
|
||||||
if name not in self.avg_values:
|
|
||||||
# add value if not exist before
|
|
||||||
self.add_value(name, init_val=value)
|
|
||||||
else:
|
|
||||||
# else update existing value
|
|
||||||
if weighted_avg:
|
|
||||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
|
||||||
self.iters[name] += 1
|
|
||||||
else:
|
|
||||||
self.avg_values[name] = self.avg_values[name] * \
|
|
||||||
self.iters[name] + value
|
|
||||||
self.iters[name] += 1
|
|
||||||
self.avg_values[name] /= self.iters[name]
|
|
||||||
|
|
||||||
def add_values(self, name_dict):
|
|
||||||
for key, value in name_dict.items():
|
|
||||||
self.add_value(key, init_val=value)
|
|
||||||
|
|
||||||
def update_values(self, value_dict):
|
|
||||||
for key, value in value_dict.items():
|
|
||||||
self.update_value(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
|
|
||||||
if alternative in c.keys() and c[alternative] is not None:
|
|
||||||
return
|
|
||||||
if restricted:
|
|
||||||
assert name in c.keys(), f' [!] {name} not defined in config.json'
|
|
||||||
if name in c.keys():
|
|
||||||
if max_val:
|
|
||||||
assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
|
|
||||||
if min_val:
|
|
||||||
assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
|
|
||||||
if enum_list:
|
|
||||||
assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
|
|
||||||
if val_type:
|
|
||||||
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
|
|
||||||
|
|
||||||
|
|
||||||
def check_config(c):
|
|
||||||
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
|
||||||
_check_argument('run_name', c, restricted=True, val_type=str)
|
|
||||||
_check_argument('run_description', c, val_type=str)
|
|
||||||
|
|
||||||
# AUDIO
|
|
||||||
_check_argument('audio', c, restricted=True, val_type=dict)
|
|
||||||
|
|
||||||
# audio processing parameters
|
|
||||||
_check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
|
|
||||||
_check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
|
|
||||||
_check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
|
|
||||||
_check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
|
|
||||||
_check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
|
|
||||||
_check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
|
|
||||||
_check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
|
|
||||||
_check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
|
|
||||||
_check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
|
|
||||||
_check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
|
|
||||||
|
|
||||||
# vocabulary parameters
|
|
||||||
_check_argument('characters', c, restricted=False, val_type=dict)
|
|
||||||
_check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
_check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
_check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
_check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
_check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
_check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
|
|
||||||
|
|
||||||
# normalization parameters
|
|
||||||
_check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
|
|
||||||
_check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
|
|
||||||
_check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
|
|
||||||
_check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
|
|
||||||
_check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
|
|
||||||
_check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
|
|
||||||
_check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100)
|
|
||||||
_check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
|
|
||||||
_check_argument('trim_db', c['audio'], restricted=True, val_type=int)
|
|
||||||
|
|
||||||
# training parameters
|
|
||||||
_check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('r', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('gradual_training', c, restricted=False, val_type=list)
|
|
||||||
_check_argument('loss_masking', c, restricted=True, val_type=bool)
|
|
||||||
# _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
|
|
||||||
|
|
||||||
# validation parameters
|
|
||||||
_check_argument('run_eval', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('test_sentences_file', c, restricted=False, val_type=str)
|
|
||||||
|
|
||||||
# optimizer
|
|
||||||
_check_argument('noam_schedule', c, restricted=False, val_type=bool)
|
|
||||||
_check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
|
|
||||||
_check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('lr', c, restricted=True, val_type=float, min_val=0)
|
|
||||||
_check_argument('wd', c, restricted=True, val_type=float, min_val=0)
|
|
||||||
_check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('seq_len_norm', c, restricted=True, val_type=bool)
|
|
||||||
|
|
||||||
# tacotron prenet
|
|
||||||
_check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
|
|
||||||
_check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
|
|
||||||
_check_argument('prenet_dropout', c, restricted=True, val_type=bool)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
_check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
|
|
||||||
_check_argument('attention_heads', c, restricted=True, val_type=int)
|
|
||||||
_check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
|
|
||||||
_check_argument('windowing', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('use_forward_attn', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('location_attn', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
|
||||||
|
|
||||||
# stopnet
|
|
||||||
_check_argument('stopnet', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('separate_stopnet', c, restricted=True, val_type=bool)
|
|
||||||
|
|
||||||
# tensorboard
|
|
||||||
_check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('checkpoint', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
|
||||||
|
|
||||||
# dataloading
|
|
||||||
# pylint: disable=import-outside-toplevel
|
|
||||||
from TTS.utils.text import cleaners
|
|
||||||
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
|
||||||
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
|
|
||||||
_check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
|
|
||||||
|
|
||||||
# paths
|
|
||||||
_check_argument('output_path', c, restricted=True, val_type=str)
|
|
||||||
|
|
||||||
# multi-speaker
|
|
||||||
_check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
|
|
||||||
|
|
||||||
# GST
|
|
||||||
_check_argument('use_gst', c, restricted=True, val_type=bool)
|
|
||||||
_check_argument('gst', c, restricted=True, val_type=dict)
|
|
||||||
_check_argument('gst_style_input', c['gst'], restricted=True, val_type=str)
|
|
||||||
_check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
|
|
||||||
_check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
|
|
||||||
|
|
||||||
# datasets - checking only the first entry
|
|
||||||
_check_argument('datasets', c, restricted=True, val_type=list)
|
|
||||||
for dataset_entry in c['datasets']:
|
|
||||||
_check_argument('name', dataset_entry, restricted=True, val_type=str)
|
|
||||||
_check_argument('path', dataset_entry, restricted=True, val_type=str)
|
|
||||||
_check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
|
|
||||||
_check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
|
Loading…
Reference in New Issue