mirror of https://github.com/coqui-ai/TTS.git
172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
import datetime
|
|
import glob
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import contextlib
|
|
import platform
|
|
|
|
import torch
|
|
|
|
|
|
def set_amp_context(mixed_precision):
    """Return the context manager that should wrap a forward pass.

    Args:
        mixed_precision (bool): when truthy, return ``torch.cuda.amp.autocast()``
            so the wrapped ops run under CUDA automatic mixed precision.

    Returns:
        A context manager: ``autocast`` when mixed precision is enabled,
        otherwise a no-op context manager.
    """
    if mixed_precision:
        return torch.cuda.amp.autocast()
    # ``contextlib.nullcontext`` only exists on Python >= 3.7. Probe for the
    # feature instead of comparing version strings: string comparison is
    # lexicographic, so ``"3.10.0" <= "3.6.0"`` is True and
    # ``"3.6.9" <= "3.6.0"`` is False, which picked the wrong branch.
    if hasattr(contextlib, "nullcontext"):
        return contextlib.nullcontext()
    # ``suppress()`` with no exception types suppresses nothing, so it acts as
    # a no-op context manager on older interpreters.
    return contextlib.suppress()
|
|
|
|
|
|
def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Falls back to ``"inside_docker"`` when ``git branch`` fails (e.g. the
    working directory is not a git repository), when the ``git`` executable
    is not installed, or when no current branch is listed.
    """
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError: the git binary itself is missing.
        return "inside_docker"
    # `git branch` marks the current branch with a leading "* ".
    current = next((line for line in out.split("\n") if line.startswith("*")), None)
    if current is None:
        # e.g. a freshly initialised repository with no commits lists no branch
        return "inside_docker"
    # Assign the result: str.replace returns a new string, it does not mutate.
    return current.replace("* ", "", 1)
|
|
|
|
|
|
def get_commit_hash():
    """Return (and print) the short hash of the current git ``HEAD`` commit.

    Based on
    https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Returns:
        str: the output of ``git rev-parse --short HEAD``, or ``"0000000"``
        when it cannot be obtained (not a git repository, or the ``git``
        executable is not installed).
    """
    try:
        commit = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    # Not copying .git folder into docker container; slim images may also lack
    # the git binary entirely, which raises FileNotFoundError.
    except (subprocess.CalledProcessError, FileNotFoundError):
        commit = "0000000"
    print(' > Git Hash: {}'.format(commit))
    return commit
|
|
|
|
|
|
def create_experiment_folder(root_path, model_name, debug):
    """Create (if missing) and return a timestamped experiment directory.

    The directory is ``<root_path>/<model_name>-<date>-<suffix>`` where
    ``<date>`` is the current local time formatted as ``"%B-%d-%Y_%I+%M%p"``
    and ``<suffix>`` is ``'debug'`` in debug mode, otherwise the short git
    commit hash. The resulting path is printed and returned.
    """
    timestamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    suffix = 'debug' if debug else get_commit_hash()
    folder_name = "-".join((model_name, timestamp, suffix))
    experiment_dir = os.path.join(root_path, folder_name)
    os.makedirs(experiment_dir, exist_ok=True)
    print(" > Experiment folder: {}".format(experiment_dir))
    return experiment_dir
|
|
|
|
|
|
def remove_experiment_folder(experiment_path):
    """Delete an experiment folder unless it holds at least one checkpoint.

    The run is kept when ``experiment_path`` directly contains a file matching
    ``*.pth.tar``; otherwise the folder (if it exists) is removed recursively,
    ignoring removal errors.

    Args:
        experiment_path (str): path to the experiment output folder.
    """
    # Escape the folder path so glob metacharacters in it (e.g. "run[1]") are
    # matched literally. Unescaped, such a pattern silently matches nothing and
    # a folder that does contain checkpoints would be deleted.
    pattern = os.path.join(glob.escape(experiment_path), "*.pth.tar")
    checkpoint_files = glob.glob(pattern)
    if not checkpoint_files:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path, ignore_errors=True)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))
|
|
|
|
|
|
def count_parameters(model):
    r"""Return the total element count of the trainable parameters of ``model``.

    Each parameter yielded by ``model.parameters()`` with ``requires_grad``
    set contributes its ``numel()``; frozen parameters are ignored.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
|
|
def set_init_dict(model_dict, checkpoint_state, c):
    """Merge compatible checkpoint entries into a model state dict.

    Partial initialization: a checkpoint entry is copied into ``model_dict``
    only if the model defines the same key, both entries report the same
    ``numel()``, and the key does not contain any substring listed in
    ``c.reinit_layers``. Everything else keeps its current value.

    Args:
        model_dict (dict): the model's state dict; updated in place.
        checkpoint_state (dict): state dict loaded from a checkpoint.
        c: config object whose ``reinit_layers`` attribute is either None or an
            iterable of layer-name substrings to leave un-restored.

    Returns:
        dict: ``model_dict`` after the update.
    """
    # Report checkpoint layers that the current model does not define.
    for key in checkpoint_state:
        if key not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(key))

    reinit_layers = c.reinit_layers
    reinit_names = () if reinit_layers is None else tuple(reinit_layers)

    def _is_restorable(key, tensor):
        # Skip keys unknown to the model and entries whose size differs.
        # NOTE(review): this compares element counts, not shapes, so a tensor
        # with the same numel but a different shape still passes — confirm
        # whether that is intended.
        if key not in model_dict or tensor.numel() != model_dict[key].numel():
            return False
        # Skip layers the config explicitly asks to re-initialize.
        return not any(name in key for name in reinit_names)

    restored = {
        key: tensor
        for key, tensor in checkpoint_state.items()
        if _is_restorable(key, tensor)
    }
    model_dict.update(restored)
    print(" | > {} / {} layers are restored.".format(len(restored),
                                                    len(model_dict)))
    return model_dict
|
|
|
|
class KeepAverage:
    """Keep running averages of named values.

    Each update either folds the value into an arithmetic mean over all
    observations (default) or into an exponential moving average with decay
    0.99 (``weighted_avg=True``).
    """

    def __init__(self):
        self.avg_values = {}  # name -> current average
        self.iters = {}  # name -> number of samples counted in the average

    def __getitem__(self, key):
        return self.avg_values[key]

    def items(self):
        """Return a view of ``(name, average)`` pairs."""
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        """(Re)initialize ``name`` with ``init_val`` counted as ``init_iter`` samples.

        With the default ``init_iter=0`` the initial value is a placeholder:
        the next arithmetic-mean update replaces it entirely.
        """
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the running average of ``name``.

        The first update of an unknown ``name`` initializes it with ``value``.
        """
        if name not in self.avg_values:
            # The first observation is a real sample, so count it once. With
            # init_iter=0 it would get zero weight in the next mean update
            # (updates 2 then 4 would average to 4 instead of 3).
            self.add_value(name, init_val=value, init_iter=1)
        elif weighted_avg:
            # exponential moving average
            self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
            self.iters[name] += 1
        else:
            # incremental arithmetic mean: (avg * n + value) / (n + 1)
            total = self.avg_values[name] * self.iters[name] + value
            self.iters[name] += 1
            self.avg_values[name] = total / self.iters[name]

    def add_values(self, name_dict):
        """Initialize several names at once via ``add_value`` (placeholder values)."""
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict):
        """Update several names at once via ``update_value`` (arithmetic mean)."""
        for key, value in value_dict.items():
            self.update_value(key, value)
|
|
|
|
|
|
def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
    """Validate the config entry ``c[name]`` using assertions.

    Args:
        name (str): key to check in ``c``.
        c (dict-like): config mapping (needs ``keys()`` and ``[]``).
        enum_list (list): if non-empty, ``c[name].lower()`` must be in it.
        max_val: inclusive upper bound; ignored when None.
        min_val: inclusive lower bound; ignored when None.
        restricted (bool): if True, ``name`` must be present in ``c``.
        val_type (type or list of types): accepted type(s). A None value
            always passes the type check.
        alternative (str): if this key is present in ``c`` with a non-None
            value, every check on ``name`` is skipped.

    Raises:
        AssertionError: if a check fails. NOTE: assertions are stripped when
            Python runs with ``-O``.
    """
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'
    if name not in c.keys():
        return
    value = c[name]
    # Test bounds against None rather than truthiness so that 0 / 0.0 bounds
    # are enforced too. None values skip the range checks, consistent with the
    # type check below, which accepts None.
    if max_val is not None and value is not None:
        assert value <= max_val, f' [!] {name} is larger than max value {max_val}'
    if min_val is not None and value is not None:
        assert value >= min_val, f' [!] {name} is smaller than min value {min_val}'
    if enum_list:
        assert value.lower() in enum_list, f' [!] {name} is not a valid value'
    if isinstance(val_type, list):
        is_valid = isinstance(value, tuple(val_type))
        assert is_valid or value is None, f' [!] {name} has wrong type - {type(value)} vs {val_type}'
    elif val_type:
        assert isinstance(value, val_type) or value is None, f' [!] {name} has wrong type - {type(value)} vs {val_type}'
|