mirror of https://github.com/coqui-ai/TTS.git
refactoring utils
This commit is contained in:
parent
720c4690db
commit
574968b249
|
@ -1,31 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import glob
|
import glob
|
||||||
|
import torch
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict, Counter
|
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(AttrDict, self).__init__(*args, **kwargs)
|
|
||||||
self.__dict__ = self
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path):
|
|
||||||
config = AttrDict()
|
|
||||||
with open(config_path, "r") as f:
|
|
||||||
input_str = f.read()
|
|
||||||
input_str = re.sub(r'\\\n', '', input_str)
|
|
||||||
input_str = re.sub(r'//.*\n', '\n', input_str)
|
|
||||||
data = json.loads(input_str)
|
|
||||||
config.update(data)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_branch():
|
def get_git_branch():
|
||||||
|
@ -83,155 +63,34 @@ def remove_experiment_folder(experiment_path):
|
||||||
print(" ! Run is kept in {}".format(experiment_path))
|
print(" ! Run is kept in {}".format(experiment_path))
|
||||||
|
|
||||||
|
|
||||||
def copy_config_file(config_file, out_path, new_fields):
|
|
||||||
config_lines = open(config_file, "r").readlines()
|
|
||||||
# add extra information fields
|
|
||||||
for key, value in new_fields.items():
|
|
||||||
if type(value) == str:
|
|
||||||
new_line = '"{}":"{}",\n'.format(key, value)
|
|
||||||
else:
|
|
||||||
new_line = '"{}":{},\n'.format(key, value)
|
|
||||||
config_lines.insert(1, new_line)
|
|
||||||
config_out_file = open(out_path, "w")
|
|
||||||
config_out_file.writelines(config_lines)
|
|
||||||
config_out_file.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _trim_model_state_dict(state_dict):
|
|
||||||
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
|
|
||||||
is loded for the next time by model.load_state(). Otherwise, it complains
|
|
||||||
about the torch.DataParallel()"""
|
|
||||||
|
|
||||||
new_state_dict = OrderedDict()
|
|
||||||
for k, v in state_dict.items():
|
|
||||||
name = k[7:] # remove `module.`
|
|
||||||
new_state_dict[name] = v
|
|
||||||
return new_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
|
|
||||||
current_step, epoch):
|
|
||||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
|
||||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
|
||||||
print(" > CHECKPOINT : {}".format(checkpoint_path))
|
|
||||||
|
|
||||||
new_state_dict = model.state_dict()
|
|
||||||
state = {
|
|
||||||
'model': new_state_dict,
|
|
||||||
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
|
||||||
'step': current_step,
|
|
||||||
'epoch': epoch,
|
|
||||||
'linear_loss': model_loss,
|
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
|
||||||
'r': model.decoder.r
|
|
||||||
}
|
|
||||||
torch.save(state, checkpoint_path)
|
|
||||||
|
|
||||||
|
|
||||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
|
||||||
current_step, epoch):
|
|
||||||
if model_loss < best_loss:
|
|
||||||
new_state_dict = model.state_dict()
|
|
||||||
state = {
|
|
||||||
'model': new_state_dict,
|
|
||||||
'optimizer': optimizer.state_dict(),
|
|
||||||
'step': current_step,
|
|
||||||
'epoch': epoch,
|
|
||||||
'linear_loss': model_loss,
|
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
|
||||||
'r': model.decoder.r
|
|
||||||
}
|
|
||||||
best_loss = model_loss
|
|
||||||
bestmodel_path = 'best_model.pth.tar'
|
|
||||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
|
||||||
print(" > BEST MODEL ({0:.5f}) : {1:}".format(
|
|
||||||
model_loss, bestmodel_path))
|
|
||||||
torch.save(state, bestmodel_path)
|
|
||||||
return best_loss
|
|
||||||
|
|
||||||
|
|
||||||
def check_update(model, grad_clip, ignore_stopnet=False):
|
|
||||||
r'''Check model gradient against unexpected jumps and failures'''
|
|
||||||
skip_flag = False
|
|
||||||
if ignore_stopnet:
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
|
|
||||||
else:
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
|
||||||
if np.isinf(grad_norm):
|
|
||||||
print(" | > Gradient is INF !!")
|
|
||||||
skip_flag = True
|
|
||||||
return grad_norm, skip_flag
|
|
||||||
|
|
||||||
|
|
||||||
def lr_decay(init_lr, global_step, warmup_steps):
|
|
||||||
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
|
|
||||||
warmup_steps = float(warmup_steps)
|
|
||||||
step = global_step + 1.
|
|
||||||
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
|
|
||||||
step**-0.5)
|
|
||||||
return lr
|
|
||||||
|
|
||||||
|
|
||||||
def adam_weight_decay(optimizer):
|
|
||||||
"""
|
|
||||||
Custom weight decay operation, not effecting grad values.
|
|
||||||
"""
|
|
||||||
for group in optimizer.param_groups:
|
|
||||||
for param in group['params']:
|
|
||||||
current_lr = group['lr']
|
|
||||||
weight_decay = group['weight_decay']
|
|
||||||
param.data = param.data.add(-weight_decay * group['lr'],
|
|
||||||
param.data)
|
|
||||||
return optimizer, current_lr
|
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
|
||||||
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
|
|
||||||
"""
|
|
||||||
Skip biases, BatchNorm parameters, rnns.
|
|
||||||
and attention projection layer v
|
|
||||||
"""
|
|
||||||
decay = []
|
|
||||||
no_decay = []
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
|
|
||||||
no_decay.append(param)
|
|
||||||
else:
|
|
||||||
decay.append(param)
|
|
||||||
return [{
|
|
||||||
'params': no_decay,
|
|
||||||
'weight_decay': 0.
|
|
||||||
}, {
|
|
||||||
'params': decay,
|
|
||||||
'weight_decay': weight_decay
|
|
||||||
}]
|
|
||||||
|
|
||||||
|
|
||||||
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
|
|
||||||
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
|
|
||||||
self.warmup_steps = float(warmup_steps)
|
|
||||||
super(NoamLR, self).__init__(optimizer, last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
step = max(self.last_epoch, 1)
|
|
||||||
return [
|
|
||||||
base_lr * self.warmup_steps**0.5 *
|
|
||||||
min(step * self.warmup_steps**-1.5, step**-0.5)
|
|
||||||
for base_lr in self.base_lrs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def mk_decay(init_mk, max_epoch, n_epoch):
|
|
||||||
return init_mk * ((max_epoch - n_epoch) / max_epoch)
|
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
def count_parameters(model):
|
||||||
r"""Count number of trainable parameters in a network"""
|
r"""Count number of trainable parameters in a network"""
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
def split_dataset(items):
|
||||||
|
is_multi_speaker = False
|
||||||
|
speakers = [item[-1] for item in items]
|
||||||
|
is_multi_speaker = len(set(speakers)) > 1
|
||||||
|
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
|
||||||
|
len(items) * 0.01)
|
||||||
|
np.random.seed(0)
|
||||||
|
np.random.shuffle(items)
|
||||||
|
if is_multi_speaker:
|
||||||
|
items_eval = []
|
||||||
|
# most stupid code ever -- Fix it !
|
||||||
|
while len(items_eval) < eval_split_size:
|
||||||
|
speakers = [item[-1] for item in items]
|
||||||
|
speaker_counter = Counter(speakers)
|
||||||
|
item_idx = np.random.randint(0, len(items))
|
||||||
|
if speaker_counter[items[item_idx][-1]] > 1:
|
||||||
|
items_eval.append(items[item_idx])
|
||||||
|
del items[item_idx]
|
||||||
|
return items_eval, items
|
||||||
|
else:
|
||||||
|
return items[:eval_split_size], items[eval_split_size:]
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||||
def sequence_mask(sequence_length, max_len=None):
|
def sequence_mask(sequence_length, max_len=None):
|
||||||
if max_len is None:
|
if max_len is None:
|
||||||
|
@ -322,44 +181,6 @@ def setup_model(num_chars, num_speakers, c):
|
||||||
bidirectional_decoder=c.bidirectional_decoder)
|
bidirectional_decoder=c.bidirectional_decoder)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(items):
|
|
||||||
is_multi_speaker = False
|
|
||||||
speakers = [item[-1] for item in items]
|
|
||||||
is_multi_speaker = len(set(speakers)) > 1
|
|
||||||
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
|
|
||||||
len(items) * 0.01)
|
|
||||||
np.random.seed(0)
|
|
||||||
np.random.shuffle(items)
|
|
||||||
if is_multi_speaker:
|
|
||||||
items_eval = []
|
|
||||||
# most stupid code ever -- Fix it !
|
|
||||||
while len(items_eval) < eval_split_size:
|
|
||||||
speakers = [item[-1] for item in items]
|
|
||||||
speaker_counter = Counter(speakers)
|
|
||||||
item_idx = np.random.randint(0, len(items))
|
|
||||||
if speaker_counter[items[item_idx][-1]] > 1:
|
|
||||||
items_eval.append(items[item_idx])
|
|
||||||
del items[item_idx]
|
|
||||||
return items_eval, items
|
|
||||||
else:
|
|
||||||
return items[:eval_split_size], items[eval_split_size:]
|
|
||||||
|
|
||||||
|
|
||||||
def gradual_training_scheduler(global_step, config):
|
|
||||||
"""Setup the gradual training schedule wrt number
|
|
||||||
of active GPUs"""
|
|
||||||
num_gpus = torch.cuda.device_count()
|
|
||||||
if num_gpus == 0:
|
|
||||||
num_gpus = 1
|
|
||||||
new_values = None
|
|
||||||
# we set the scheduling wrt num_gpus
|
|
||||||
for values in config.gradual_training:
|
|
||||||
if global_step * num_gpus >= values[0]:
|
|
||||||
new_values = values
|
|
||||||
return new_values[1], new_values[2]
|
|
||||||
|
|
||||||
|
|
||||||
class KeepAverage():
|
class KeepAverage():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.avg_values = {}
|
self.avg_values = {}
|
||||||
|
@ -410,30 +231,6 @@ def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restric
|
||||||
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
|
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
|
||||||
|
|
||||||
|
|
||||||
tcolors = AttrDict({
|
|
||||||
'OKBLUE': '\033[94m',
|
|
||||||
'HEADER': '\033[95m',
|
|
||||||
'OKGREEN': '\033[92m',
|
|
||||||
'WARNING': '\033[93m',
|
|
||||||
'FAIL': '\033[91m',
|
|
||||||
'ENDC': '\033[0m',
|
|
||||||
'BOLD': '\033[1m',
|
|
||||||
'UNDERLINE': '\033[4m'
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_length, step_time, loader_time, lr, print_dict):
|
|
||||||
indent = " | > "
|
|
||||||
print()
|
|
||||||
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
|
|
||||||
for key, value in print_dict.items():
|
|
||||||
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
|
|
||||||
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\
|
|
||||||
\n{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"\
|
|
||||||
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
|
|
||||||
print(log_text, flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def check_config(c):
|
def check_config(c):
|
||||||
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||||
_check_argument('run_name', c, restricted=True, val_type=str)
|
_check_argument('run_name', c, restricted=True, val_type=str)
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttrDict, self).__init__(*args, **kwargs)
|
||||||
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path):
|
||||||
|
config = AttrDict()
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
input_str = f.read()
|
||||||
|
input_str = re.sub(r'\\\n', '', input_str)
|
||||||
|
input_str = re.sub(r'//.*\n', '\n', input_str)
|
||||||
|
data = json.loads(input_str)
|
||||||
|
config.update(data)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def copy_config_file(config_file, out_path, new_fields):
|
||||||
|
config_lines = open(config_file, "r").readlines()
|
||||||
|
# add extra information fields
|
||||||
|
for key, value in new_fields.items():
|
||||||
|
if type(value) == str:
|
||||||
|
new_line = '"{}":"{}",\n'.format(key, value)
|
||||||
|
else:
|
||||||
|
new_line = '"{}":{},\n'.format(key, value)
|
||||||
|
config_lines.insert(1, new_line)
|
||||||
|
config_out_file = open(out_path, "w")
|
||||||
|
config_out_file.writelines(config_lines)
|
||||||
|
config_out_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
model.load_state_dict(state['model'])
|
||||||
|
if use_cuda:
|
||||||
|
model.cuda()
|
||||||
|
# set model stepsize
|
||||||
|
if 'r' in state.keys():
|
||||||
|
model.decoder.set_r(state['r'])
|
||||||
|
return model, state
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, **kwargs):
|
||||||
|
checkpoint_path = os.path.join(output_folder, file_name)
|
||||||
|
|
||||||
|
new_state_dict = model.state_dict()
|
||||||
|
state = {
|
||||||
|
'model': new_state_dict,
|
||||||
|
'optimizer': optimizer.state_dict() if optimizer is not None else None,
|
||||||
|
'step': current_step,
|
||||||
|
'epoch': epoch,
|
||||||
|
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||||
|
'r': model.decoder.r
|
||||||
|
}
|
||||||
|
state.update(kwargs)
|
||||||
|
torch.save(state, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
|
||||||
|
print(" > CHECKPOINT : {}".format(checkpoint_path))
|
||||||
|
file_name = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
|
save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
|
||||||
|
if target_loss < best_loss:
|
||||||
|
print(" > BEST MODEL : {}".format(checkpoint_path))
|
||||||
|
file_name = 'best_model.pth.tar'
|
||||||
|
save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, model_loss=target_loss)
|
||||||
|
best_loss = target_loss
|
||||||
|
return best_loss
|
|
@ -0,0 +1,90 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def check_update(model, grad_clip, ignore_stopnet=False):
|
||||||
|
r'''Check model gradient against unexpected jumps and failures'''
|
||||||
|
skip_flag = False
|
||||||
|
if ignore_stopnet:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
|
||||||
|
else:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||||
|
if torch.isinf(grad_norm):
|
||||||
|
print(" | > Gradient is INF !!")
|
||||||
|
skip_flag = True
|
||||||
|
return grad_norm, skip_flag
|
||||||
|
|
||||||
|
|
||||||
|
def lr_decay(init_lr, global_step, warmup_steps):
|
||||||
|
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
|
||||||
|
warmup_steps = float(warmup_steps)
|
||||||
|
step = global_step + 1.
|
||||||
|
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
|
||||||
|
step**-0.5)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def adam_weight_decay(optimizer):
|
||||||
|
"""
|
||||||
|
Custom weight decay operation, not effecting grad values.
|
||||||
|
"""
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
for param in group['params']:
|
||||||
|
current_lr = group['lr']
|
||||||
|
weight_decay = group['weight_decay']
|
||||||
|
factor = -weight_decay * group['lr']
|
||||||
|
param.data = param.data.add(param.data,
|
||||||
|
alpha=factor)
|
||||||
|
return optimizer, current_lr
|
||||||
|
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
|
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
|
||||||
|
"""
|
||||||
|
Skip biases, BatchNorm parameters, rnns.
|
||||||
|
and attention projection layer v
|
||||||
|
"""
|
||||||
|
decay = []
|
||||||
|
no_decay = []
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
|
||||||
|
no_decay.append(param)
|
||||||
|
else:
|
||||||
|
decay.append(param)
|
||||||
|
return [{
|
||||||
|
'params': no_decay,
|
||||||
|
'weight_decay': 0.
|
||||||
|
}, {
|
||||||
|
'params': decay,
|
||||||
|
'weight_decay': weight_decay
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
|
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
|
||||||
|
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
|
||||||
|
self.warmup_steps = float(warmup_steps)
|
||||||
|
super(NoamLR, self).__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
step = max(self.last_epoch, 1)
|
||||||
|
return [
|
||||||
|
base_lr * self.warmup_steps**0.5 *
|
||||||
|
min(step * self.warmup_steps**-1.5, step**-0.5)
|
||||||
|
for base_lr in self.base_lrs
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def gradual_training_scheduler(global_step, config):
|
||||||
|
"""Setup the gradual training schedule wrt number
|
||||||
|
of active GPUs"""
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
if num_gpus == 0:
|
||||||
|
num_gpus = 1
|
||||||
|
new_values = None
|
||||||
|
# we set the scheduling wrt num_gpus
|
||||||
|
for values in config.gradual_training:
|
||||||
|
if global_step * num_gpus >= values[0]:
|
||||||
|
new_values = values
|
||||||
|
return new_values[1], new_values[2]
|
Loading…
Reference in New Issue