diff --git a/train.py b/train.py
index 13444c82..7a68e2b0 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
                                      load_config, remove_experiment_folder,
                                      save_best_model, save_checkpoint,
                                      weight_decay, set_init_dict, copy_config_file, setup_model,
-                                     split_dataset, gradual_training_scheduler, KeepAverage)
+                                     split_dataset, gradual_training_scheduler, KeepAverage,
+                                     set_weight_decay)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -186,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             loss += stop_loss
 
         loss.backward()
-        optimizer, current_lr = weight_decay(optimizer, c.wd)
+        optimizer, current_lr = weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
@@ -197,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
             stop_loss.backward()
-            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            optimizer_st, _ = weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
         else:
@@ -511,7 +512,8 @@ def main(args):  # pylint: disable=redefined-outer-name
 
 
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    params = set_weight_decay(model, c.wd)
+    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
         optimizer_st = RAdam(
             model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index bfa72a35..3cdf74bc 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -31,8 +31,8 @@ def load_config(config_path):
 def get_git_branch():
     try:
         out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split(
-            "\n") if line.startswith("*"))
+        current = next(line for line in out.split("\n")
+                       if line.startswith("*"))
         current.replace("* ", "")
     except subprocess.CalledProcessError:
         current = "inside_docker"
@@ -48,8 +48,8 @@ def get_commit_hash():
     # raise RuntimeError(
     #     " !! Commit before training to get the commit hash.")
     try:
-        commit = subprocess.check_output(['git', 'rev-parse', '--short',
-                                          'HEAD']).decode().strip()
+        commit = subprocess.check_output(
+            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
     # Not copying .git folder into docker container
     except subprocess.CalledProcessError:
         commit = "0000000"
@@ -169,17 +169,43 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
-def weight_decay(optimizer, wd):
+def weight_decay(optimizer):
     """
     Custom weight decay operation, not affecting grad values.
""" for group in optimizer.param_groups: for param in group['params']: current_lr = group['lr'] - param.data = param.data.add(-wd * group['lr'], param.data) + weight_decay = group['weight_decay'] + param.data = param.data.add(-weight_decay * group['lr'], + param.data) return optimizer, current_lr +def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}): + """ + Skip biases, BatchNorm parameters for weight decay + and attention projection layer v + """ + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if len(param.shape) == 1 or name in skip_list: + print(name) + no_decay.append(param) + else: + decay.append(param) + return [{ + 'params': no_decay, + 'weight_decay': 0. + }, { + 'params': decay, + 'weight_decay': weight_decay + }] + + class NoamLR(torch.optim.lr_scheduler._LRScheduler): def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): self.warmup_steps = float(warmup_steps) @@ -188,8 +214,8 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler): def get_lr(self): step = max(self.last_epoch, 1) return [ - base_lr * self.warmup_steps**0.5 * min( - step * self.warmup_steps**-1.5, step**-0.5) + base_lr * self.warmup_steps**0.5 * + min(step * self.warmup_steps**-1.5, step**-0.5) for base_lr in self.base_lrs ] @@ -244,8 +270,8 @@ def set_init_dict(model_dict, checkpoint, c): } # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format( - len(pretrained_dict), len(model_dict))) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), + len(model_dict))) return model_dict @@ -254,37 +280,35 @@ def setup_model(num_chars, num_speakers, c): MyModel = importlib.import_module('TTS.models.' 
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in "tacotron":
-        model = MyModel(
-            num_chars=num_chars,
-            num_speakers=num_speakers,
-            r=c.r,
-            linear_dim=1025,
-            mel_dim=80,
-            gst=c.use_gst,
-            memory_size=c.memory_size,
-            attn_win=c.windowing,
-            attn_norm=c.attention_norm,
-            prenet_type=c.prenet_type,
-            prenet_dropout=c.prenet_dropout,
-            forward_attn=c.use_forward_attn,
-            trans_agent=c.transition_agent,
-            forward_attn_mask=c.forward_attn_mask,
-            location_attn=c.location_attn,
-            separate_stopnet=c.separate_stopnet)
+        model = MyModel(num_chars=num_chars,
+                        num_speakers=num_speakers,
+                        r=c.r,
+                        linear_dim=1025,
+                        mel_dim=80,
+                        gst=c.use_gst,
+                        memory_size=c.memory_size,
+                        attn_win=c.windowing,
+                        attn_norm=c.attention_norm,
+                        prenet_type=c.prenet_type,
+                        prenet_dropout=c.prenet_dropout,
+                        forward_attn=c.use_forward_attn,
+                        trans_agent=c.transition_agent,
+                        forward_attn_mask=c.forward_attn_mask,
+                        location_attn=c.location_attn,
+                        separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
-        model = MyModel(
-            num_chars=num_chars,
-            num_speakers=num_speakers,
-            r=c.r,
-            attn_win=c.windowing,
-            attn_norm=c.attention_norm,
-            prenet_type=c.prenet_type,
-            prenet_dropout=c.prenet_dropout,
-            forward_attn=c.use_forward_attn,
-            trans_agent=c.transition_agent,
-            forward_attn_mask=c.forward_attn_mask,
-            location_attn=c.location_attn,
-            separate_stopnet=c.separate_stopnet)
+        model = MyModel(num_chars=num_chars,
+                        num_speakers=num_speakers,
+                        r=c.r,
+                        attn_win=c.windowing,
+                        attn_norm=c.attention_norm,
+                        prenet_type=c.prenet_type,
+                        prenet_dropout=c.prenet_dropout,
+                        forward_attn=c.use_forward_attn,
+                        trans_agent=c.transition_agent,
+                        forward_attn_mask=c.forward_attn_mask,
+                        location_attn=c.location_attn,
+                        separate_stopnet=c.separate_stopnet)
     return model
 
 
@@ -292,7 +316,8 @@ def split_dataset(items):
     is_multi_speaker = False
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(
+        len(items) * 0.01)
     np.random.seed(0)
     np.random.shuffle(items)
     if is_multi_speaker:
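
Reviewer note: below is a minimal sketch of how the new parameter grouping is consumed end to end. The toy torch.nn.Sequential model is a hypothetical stand-in for Tacotron, torch.optim.Adam stands in for the repo's RAdam, and the patched TTS.utils.generic_utils is assumed to be importable.

    import torch

    # The two functions this diff changes/adds.
    from TTS.utils.generic_utils import set_weight_decay, weight_decay

    # Hypothetical toy model standing in for Tacotron.
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),   # 2-D weight -> decay group
        torch.nn.BatchNorm1d(10),  # 1-D gain/bias -> no-decay group
    )

    # 1-D tensors (biases, norm params) and names in skip_list get
    # weight_decay=0; everything else gets the configured value (c.wd).
    params = set_weight_decay(model, 0.01)
    assert params[0]['weight_decay'] == 0.    # no-decay group
    assert params[1]['weight_decay'] == 0.01  # decay group

    # The groups carry their own decay, so the optimizer itself is built
    # with weight_decay=0, exactly as train.py now does with RAdam.
    optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=0)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()

    # Decoupled decay step: weight_decay() now reads each group's
    # 'weight_decay' entry instead of a global c.wd. Note it relies on
    # the legacy Tensor.add(alpha, tensor) overload, i.e. the PyTorch
    # version the repo targeted at the time.
    optimizer, current_lr = weight_decay(optimizer)
    optimizer.step()

Because the decay is applied directly to the weights before the gradient step rather than folded into the gradients, the optimizer's adaptive-moment statistics are unaffected by the decay term (the AdamW-style decoupling the docstring describes).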