skip weight decay for BN and biases, some formatting

This commit is contained in:
Eren Golge 2019-09-28 01:09:28 +02:00
parent 53d658fb74
commit b76aaf8ad4
2 changed files with 72 additions and 45 deletions

View File

@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
load_config, remove_experiment_folder, load_config, remove_experiment_folder,
save_best_model, save_checkpoint, weight_decay, save_best_model, save_checkpoint, weight_decay,
set_init_dict, copy_config_file, setup_model, set_init_dict, copy_config_file, setup_model,
split_dataset, gradual_training_scheduler, KeepAverage) split_dataset, gradual_training_scheduler, KeepAverage,
set_weight_decay)
from TTS.utils.logger import Logger from TTS.utils.logger import Logger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
get_speakers get_speakers
@ -186,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
loss += stop_loss loss += stop_loss
loss.backward() loss.backward()
optimizer, current_lr = weight_decay(optimizer, c.wd) optimizer, current_lr = weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip) grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step() optimizer.step()
@ -197,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# backpass and check the grad norm for stop loss # backpass and check the grad norm for stop loss
if c.separate_stopnet: if c.separate_stopnet:
stop_loss.backward() stop_loss.backward()
optimizer_st, _ = weight_decay(optimizer_st, c.wd) optimizer_st, _ = weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step() optimizer_st.step()
else: else:
@ -511,7 +512,8 @@ def main(args): # pylint: disable=redefined-outer-name
print(" | > Num output units : {}".format(ap.num_freq), flush=True) print(" | > Num output units : {}".format(ap.num_freq), flush=True)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0) params = set_weight_decay(model, c.wd)
optimizer = RAdam(params, lr=c.lr, weight_decay=0)
if c.stopnet and c.separate_stopnet: if c.stopnet and c.separate_stopnet:
optimizer_st = RAdam( optimizer_st = RAdam(
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0) model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)

View File

@ -31,8 +31,8 @@ def load_config(config_path):
def get_git_branch(): def get_git_branch():
try: try:
out = subprocess.check_output(["git", "branch"]).decode("utf8") out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split( current = next(line for line in out.split("\n")
"\n") if line.startswith("*")) if line.startswith("*"))
current.replace("* ", "") current.replace("* ", "")
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
current = "inside_docker" current = "inside_docker"
@ -48,8 +48,8 @@ def get_commit_hash():
# raise RuntimeError( # raise RuntimeError(
# " !! Commit before training to get the commit hash.") # " !! Commit before training to get the commit hash.")
try: try:
commit = subprocess.check_output(['git', 'rev-parse', '--short', commit = subprocess.check_output(
'HEAD']).decode().strip() ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
# Not copying .git folder into docker container # Not copying .git folder into docker container
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
commit = "0000000" commit = "0000000"
@ -169,17 +169,43 @@ def lr_decay(init_lr, global_step, warmup_steps):
return lr return lr
def weight_decay(optimizer, wd): def weight_decay(optimizer):
""" """
Custom weight decay operation, not effecting grad values. Custom weight decay operation, not effecting grad values.
""" """
for group in optimizer.param_groups: for group in optimizer.param_groups:
for param in group['params']: for param in group['params']:
current_lr = group['lr'] current_lr = group['lr']
param.data = param.data.add(-wd * group['lr'], param.data) weight_decay = group['weight_decay']
param.data = param.data.add(-weight_decay * group['lr'],
param.data)
return optimizer, current_lr return optimizer, current_lr
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
"""
Skip biases, BatchNorm parameters for weight decay
and attention projection layer v
"""
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name in skip_list:
print(name)
no_decay.append(param)
else:
decay.append(param)
return [{
'params': no_decay,
'weight_decay': 0.
}, {
'params': decay,
'weight_decay': weight_decay
}]
class NoamLR(torch.optim.lr_scheduler._LRScheduler): class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps) self.warmup_steps = float(warmup_steps)
@ -188,8 +214,8 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def get_lr(self): def get_lr(self):
step = max(self.last_epoch, 1) step = max(self.last_epoch, 1)
return [ return [
base_lr * self.warmup_steps**0.5 * min( base_lr * self.warmup_steps**0.5 *
step * self.warmup_steps**-1.5, step**-0.5) min(step * self.warmup_steps**-1.5, step**-0.5)
for base_lr in self.base_lrs for base_lr in self.base_lrs
] ]
@ -244,8 +270,8 @@ def set_init_dict(model_dict, checkpoint, c):
} }
# 4. overwrite entries in the existing state dict # 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format( print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
len(pretrained_dict), len(model_dict))) len(model_dict)))
return model_dict return model_dict
@ -254,37 +280,35 @@ def setup_model(num_chars, num_speakers, c):
MyModel = importlib.import_module('TTS.models.' + c.model.lower()) MyModel = importlib.import_module('TTS.models.' + c.model.lower())
MyModel = getattr(MyModel, c.model) MyModel = getattr(MyModel, c.model)
if c.model.lower() in "tacotron": if c.model.lower() in "tacotron":
model = MyModel( model = MyModel(num_chars=num_chars,
num_chars=num_chars, num_speakers=num_speakers,
num_speakers=num_speakers, r=c.r,
r=c.r, linear_dim=1025,
linear_dim=1025, mel_dim=80,
mel_dim=80, gst=c.use_gst,
gst=c.use_gst, memory_size=c.memory_size,
memory_size=c.memory_size, attn_win=c.windowing,
attn_win=c.windowing, attn_norm=c.attention_norm,
attn_norm=c.attention_norm, prenet_type=c.prenet_type,
prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout,
prenet_dropout=c.prenet_dropout, forward_attn=c.use_forward_attn,
forward_attn=c.use_forward_attn, trans_agent=c.transition_agent,
trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask,
forward_attn_mask=c.forward_attn_mask, location_attn=c.location_attn,
location_attn=c.location_attn, separate_stopnet=c.separate_stopnet)
separate_stopnet=c.separate_stopnet)
elif c.model.lower() == "tacotron2": elif c.model.lower() == "tacotron2":
model = MyModel( model = MyModel(num_chars=num_chars,
num_chars=num_chars, num_speakers=num_speakers,
num_speakers=num_speakers, r=c.r,
r=c.r, attn_win=c.windowing,
attn_win=c.windowing, attn_norm=c.attention_norm,
attn_norm=c.attention_norm, prenet_type=c.prenet_type,
prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout,
prenet_dropout=c.prenet_dropout, forward_attn=c.use_forward_attn,
forward_attn=c.use_forward_attn, trans_agent=c.transition_agent,
trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask,
forward_attn_mask=c.forward_attn_mask, location_attn=c.location_attn,
location_attn=c.location_attn, separate_stopnet=c.separate_stopnet)
separate_stopnet=c.separate_stopnet)
return model return model
@ -292,7 +316,8 @@ def split_dataset(items):
is_multi_speaker = False is_multi_speaker = False
speakers = [item[-1] for item in items] speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1 is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01) eval_split_size = 500 if 500 < len(items) * 0.01 else int(
len(items) * 0.01)
np.random.seed(0) np.random.seed(0)
np.random.shuffle(items) np.random.shuffle(items)
if is_multi_speaker: if is_multi_speaker: