refactoring utils

This commit is contained in:
erogol 2020-05-12 13:48:51 +02:00
parent 720c4690db
commit 574968b249
3 changed files with 192 additions and 227 deletions


@@ -1,31 +1,11 @@
import os
import re
import glob
import torch
import shutil
import datetime
import json
import subprocess
import importlib
import numpy as np
from collections import OrderedDict, Counter
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def load_config(config_path):
config = AttrDict()
with open(config_path, "r") as f:
input_str = f.read()
input_str = re.sub(r'\\\n', '', input_str)
input_str = re.sub(r'//.*\n', '\n', input_str)
data = json.loads(input_str)
config.update(data)
return config
def get_git_branch():
@@ -83,155 +63,34 @@ def remove_experiment_folder(experiment_path):
print(" ! Run is kept in {}".format(experiment_path))
def copy_config_file(config_file, out_path, new_fields):
config_lines = open(config_file, "r").readlines()
# add extra information fields
for key, value in new_fields.items():
if type(value) == str:
new_line = '"{}":"{}",\n'.format(key, value)
else:
new_line = '"{}":{},\n'.format(key, value)
config_lines.insert(1, new_line)
config_out_file = open(out_path, "w")
config_out_file.writelines(config_lines)
config_out_file.close()
def _trim_model_state_dict(state_dict):
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
is loded for the next time by model.load_state(). Otherwise, it complains
about the torch.DataParallel()"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
return new_state_dict
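For illustration, a minimal sketch of what the trimming does, using a hypothetical state dict saved from a torch.nn.DataParallel() wrapper:

import torch
from collections import OrderedDict

# keys saved from a DataParallel-wrapped model carry a 'module.' prefix
wrapped = OrderedDict([('module.linear.weight', torch.zeros(4, 4)),
                       ('module.linear.bias', torch.zeros(4))])
trimmed = _trim_model_state_dict(wrapped)
print(list(trimmed.keys()))  # ['linear.weight', 'linear.bias']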
def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" > CHECKPOINT : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
}
torch.save(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'
bestmodel_path = os.path.join(out_path, bestmodel_path)
print(" > BEST MODEL ({0:.5f}) : {1:}".format(
model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
return best_loss
def check_update(model, grad_clip, ignore_stopnet=False):
r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False
if ignore_stopnet:
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag
def lr_decay(init_lr, global_step, warmup_steps):
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
warmup_steps = float(warmup_steps)
step = global_step + 1.
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
step**-0.5)
return lr
def adam_weight_decay(optimizer):
"""
Custom weight decay operation, not affecting gradient values.
"""
for group in optimizer.param_groups:
for param in group['params']:
current_lr = group['lr']
weight_decay = group['weight_decay']
param.data = param.data.add(-weight_decay * group['lr'],
param.data)
return optimizer, current_lr
# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
"""
Skip weight decay for biases, BatchNorm parameters, RNN weights,
and the attention projection layer v.
"""
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
no_decay.append(param)
else:
decay.append(param)
return [{
'params': no_decay,
'weight_decay': 0.
}, {
'params': decay,
'weight_decay': weight_decay
}]
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps)
super(NoamLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
step = max(self.last_epoch, 1)
return [
base_lr * self.warmup_steps**0.5 *
min(step * self.warmup_steps**-1.5, step**-0.5)
for base_lr in self.base_lrs
]
def mk_decay(init_mk, max_epoch, n_epoch):
return init_mk * ((max_epoch - n_epoch) / max_epoch)
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def split_dataset(items):
is_multi_speaker = False
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
len(items) * 0.01)
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
# most stupid code ever -- Fix it !
while len(items_eval) < eval_split_size:
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
item_idx = np.random.randint(0, len(items))
if speaker_counter[items[item_idx][-1]] > 1:
items_eval.append(items[item_idx])
del items[item_idx]
return items_eval, items
else:
return items[:eval_split_size], items[eval_split_size:]
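A brief usage sketch (the metadata below is hypothetical; items are assumed to be [text, wav_path, speaker_name] entries, since only item[-1] is read as the speaker):

# hypothetical single-speaker metadata
items = [["some text", "wavs/{}.wav".format(i), "speaker_0"] for i in range(200)]
eval_items, train_items = split_dataset(items)
print(len(eval_items), len(train_items))  # roughly 1% goes to eval (here 2 / 198)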
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
@@ -322,44 +181,6 @@ def setup_model(num_chars, num_speakers, c):
bidirectional_decoder=c.bidirectional_decoder)
return model
def split_dataset(items):
is_multi_speaker = False
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
len(items) * 0.01)
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
# most stupid code ever -- Fix it !
while len(items_eval) < eval_split_size:
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
item_idx = np.random.randint(0, len(items))
if speaker_counter[items[item_idx][-1]] > 1:
items_eval.append(items[item_idx])
del items[item_idx]
return items_eval, items
else:
return items[:eval_split_size], items[eval_split_size:]
def gradual_training_scheduler(global_step, config):
"""Setup the gradual training schedule wrt number
of active GPUs"""
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
num_gpus = 1
new_values = None
# we set the scheduling wrt num_gpus
for values in config.gradual_training:
if global_step * num_gpus >= values[0]:
new_values = values
return new_values[1], new_values[2]
class KeepAverage():
def __init__(self):
self.avg_values = {}
@@ -410,30 +231,6 @@ def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restric
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
tcolors = AttrDict({
'OKBLUE': '\033[94m',
'HEADER': '\033[95m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
'ENDC': '\033[0m',
'BOLD': '\033[1m',
'UNDERLINE': '\033[4m'
})
def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_length, step_time, loader_time, lr, print_dict):
indent = " | > "
print()
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
for key, value in print_dict.items():
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
log_text += (f"{indent}avg_spec_len: {avg_spec_length}\n"
             f"{indent}avg_text_len: {avg_text_length}\n"
             f"{indent}step_time: {step_time:.2f}\n"
             f"{indent}loader_time: {loader_time:.2f}\n"
             f"{indent}lr: {lr:.5f}")
print(log_text, flush=True)
def check_config(c):
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
_check_argument('run_name', c, restricted=True, val_type=str)

utils/io.py (new file, 78 lines)

@@ -0,0 +1,78 @@
import os
import json
import re
import torch
import datetime
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def load_config(config_path):
config = AttrDict()
with open(config_path, "r") as f:
input_str = f.read()
input_str = re.sub(r'\\\n', '', input_str)
input_str = re.sub(r'//.*\n', '\n', input_str)
data = json.loads(input_str)
config.update(data)
return config
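A minimal usage sketch (file name is hypothetical); note that load_config strips // comments and escaped line breaks before handing the text to json.loads:

c = load_config("config.json")
print(c.run_name)      # attribute-style access via AttrDict ...
print(c["run_name"])   # ... and plain dict access both work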
def copy_config_file(config_file, out_path, new_fields):
config_lines = open(config_file, "r").readlines()
# add extra information fields
for key, value in new_fields.items():
if type(value) == str:
new_line = '"{}":"{}",\n'.format(key, value)
else:
new_line = '"{}":{},\n'.format(key, value)
config_lines.insert(1, new_line)
config_out_file = open(out_path, "w")
config_out_file.writelines(config_lines)
config_out_file.close()
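A sketch of a typical call (paths and fields are illustrative): the extra fields are written as new lines right after the opening '{' of the copied JSON:

copy_config_file("config.json", "output/run-1/config.json",
                 {"github_branch": "master", "restore_path": ""})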
def load_checkpoint(model, checkpoint_path, use_cuda=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state['model'])
if use_cuda:
model.cuda()
# set model stepsize
if 'r' in state.keys():
model.decoder.set_r(state['r'])
return model, state
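A usage sketch, assuming `model` is a Tacotron-style model whose decoder provides set_r() (as load_checkpoint expects) and a hypothetical checkpoint path:

model, state = load_checkpoint(model, "output/checkpoint_10000.pth.tar",
                               use_cuda=torch.cuda.is_available())
print(state['step'], state['epoch'], state['r'])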
def save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, **kwargs):
checkpoint_path = os.path.join(output_folder, file_name)
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
}
state.update(kwargs)
torch.save(state, checkpoint_path)
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, **kwargs)
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
    if target_loss < best_loss:
        file_name = 'best_model.pth.tar'
        checkpoint_path = os.path.join(output_folder, file_name)
        print(" > BEST MODEL : {}".format(checkpoint_path))
        save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, model_loss=target_loss)
        best_loss = target_loss
    return best_loss
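A sketch of how the two wrappers might be called at the end of an evaluation pass; OUT_PATH, global_step, train_loss, eval_loss and best_loss are placeholders, and any extra keyword (e.g. model_loss) ends up in the checkpoint via state.update(kwargs):

save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                model_loss=train_loss)
best_loss = save_best_model(eval_loss, best_loss, model, optimizer,
                            global_step, epoch, model.decoder.r, OUT_PATH)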

utils/training.py (new file, 90 lines)

@@ -0,0 +1,90 @@
import torch
import numpy as np
def check_update(model, grad_clip, ignore_stopnet=False):
r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False
if ignore_stopnet:
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag
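A sketch of the intended call site inside a training step (model, optimizer, loss and the clip value are placeholders): clip the gradients and skip the optimizer update when the norm comes back infinite:

loss.backward()
grad_norm, skip_flag = check_update(model, grad_clip=1.0)
if skip_flag:
    optimizer.zero_grad()   # drop this update entirely
else:
    optimizer.step()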
def lr_decay(init_lr, global_step, warmup_steps):
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
warmup_steps = float(warmup_steps)
step = global_step + 1.
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
step**-0.5)
return lr
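For intuition, a short sketch of the schedule (init_lr and warmup_steps values are illustrative): the rate ramps up until roughly warmup_steps and then decays as step**-0.5:

for step in (0, 1000, 4000, 16000, 64000):
    print(step, lr_decay(init_lr=1e-3, global_step=step, warmup_steps=4000))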
def adam_weight_decay(optimizer):
"""
Custom weight decay operation, not affecting gradient values.
"""
for group in optimizer.param_groups:
for param in group['params']:
current_lr = group['lr']
weight_decay = group['weight_decay']
factor = -weight_decay * group['lr']
param.data = param.data.add(param.data,
alpha=factor)
return optimizer, current_lr
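Put differently, the update above rescales each parameter by (1 - lr * weight_decay) using its own group's values, i.e. weight decay decoupled from Adam's moment estimates. A sketch of where it would sit in the loop (optimizer is a placeholder):

optimizer.step()
optimizer, current_lr = adam_weight_decay(optimizer)
# each param has been rescaled by (1 - group_lr * group_weight_decay)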
# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
"""
Skip weight decay for biases, BatchNorm parameters, RNN weights,
and the attention projection layer v.
"""
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
no_decay.append(param)
else:
decay.append(param)
return [{
'params': no_decay,
'weight_decay': 0.
}, {
'params': decay,
'weight_decay': weight_decay
}]
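The returned list follows torch.optim's per-parameter-group format, so it can be handed straight to an optimizer (model is any nn.Module; the lr and weight_decay values are illustrative):

params = set_weight_decay(model, weight_decay=1e-6)
optimizer = torch.optim.Adam(params, lr=1e-3)
# group 0: biases, norm/rnn/embedding params, no decay; group 1: remaining weights, decayed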
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps)
super(NoamLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
step = max(self.last_epoch, 1)
return [
base_lr * self.warmup_steps**0.5 *
min(step * self.warmup_steps**-1.5, step**-0.5)
for base_lr in self.base_lrs
]
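A usage sketch: NoamLR is meant to be stepped once per training iteration, so last_epoch effectively counts steps and the lr follows the same warmup curve as lr_decay above (model, loader and the warmup value are placeholders):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = NoamLR(optimizer, warmup_steps=4000)
for batch in loader:        # placeholder training loop
    # ... forward / backward ...
    optimizer.step()
    scheduler.step()        # advance the schedule every iteration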
def gradual_training_scheduler(global_step, config):
"""Setup the gradual training schedule wrt number
of active GPUs"""
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
num_gpus = 1
new_values = None
# we set the scheduling wrt num_gpus
for values in config.gradual_training:
if global_step * num_gpus >= values[0]:
new_values = values
return new_values[1], new_values[2]
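A usage sketch, assuming each config.gradual_training entry has the form [start_step, r, batch_size] (the concrete schedule below is illustrative; c is a loaded config):

# "gradual_training": [[0, 7, 64], [10000, 5, 64], [50000, 3, 32]]
r, batch_size = gradual_training_scheduler(20000, c)
# with a single GPU, step 20000 picks the [10000, 5, 64] entry -> r=5, batch_size=64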