mirror of https://github.com/coqui-ai/TTS.git

skip weight decay for BN and biases, some formatting

parent 53d658fb74
commit b76aaf8ad4

Changed files: train.py (10 lines changed), TTS/utils/generic_utils.py
train.py

@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
                                      load_config, remove_experiment_folder,
                                      save_best_model, save_checkpoint, weight_decay,
                                      set_init_dict, copy_config_file, setup_model,
-                                     split_dataset, gradual_training_scheduler, KeepAverage)
+                                     split_dataset, gradual_training_scheduler, KeepAverage,
+                                     set_weight_decay)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -186,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             loss += stop_loss

         loss.backward()
-        optimizer, current_lr = weight_decay(optimizer, c.wd)
+        optimizer, current_lr = weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

@@ -197,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
             stop_loss.backward()
-            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            optimizer_st, _ = weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
         else:
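Note: both hunks keep the same update order; only the decay value now comes from each param group instead of the global c.wd. A minimal, self-contained sketch (not the repo's exact code) of that ordering: decay shrinks the weights in place between backward() and step(), so gradient values are never touched. The 'custom_wd' key is a hypothetical name used here to keep SGD's built-in decay disabled; the repo reads the standard 'weight_decay' group entry.

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer.param_groups[0]['custom_wd'] = 1e-6

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    # same role as the repo's weight_decay(optimizer) helper:
    for group in optimizer.param_groups:
        for param in group['params']:
            param.data.add_(param.data, alpha=-group['custom_wd'] * group['lr'])
    optimizer.step()  # gradients arrive unmodified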
@@ -511,7 +512,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    params = set_weight_decay(model, c.wd)
+    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
         optimizer_st = RAdam(
             model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
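A sketch of what the new wiring amounts to (torch.optim.Adam stands in for RAdam here): the constructor's weight_decay=0 is only the group default, and each group dict produced by set_weight_decay (defined in the generic_utils hunks below) overrides it per group.

    import torch

    net = torch.nn.Linear(8, 8)
    groups = [
        {'params': [net.weight], 'weight_decay': 1e-6},  # 2-D weight: decayed
        {'params': [net.bias], 'weight_decay': 0.0},     # bias: decay skipped
    ]
    opt = torch.optim.Adam(groups, lr=1e-4, weight_decay=0)
    print([g['weight_decay'] for g in opt.param_groups])  # [1e-06, 0.0]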
TTS/utils/generic_utils.py

@@ -31,8 +31,8 @@ def load_config(config_path):
 def get_git_branch():
     try:
         out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split(
-            "\n") if line.startswith("*"))
+        current = next(line for line in out.split("\n")
+                       if line.startswith("*"))
         current.replace("* ", "")
     except subprocess.CalledProcessError:
         current = "inside_docker"
@@ -48,8 +48,8 @@ def get_commit_hash():
     # raise RuntimeError(
     #     " !! Commit before training to get the commit hash.")
     try:
-        commit = subprocess.check_output(['git', 'rev-parse', '--short',
-                                          'HEAD']).decode().strip()
+        commit = subprocess.check_output(
+            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
     # Not copying .git folder into docker container
     except subprocess.CalledProcessError:
         commit = "0000000"
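Both helpers follow the same pattern: ask git through subprocess and fall back to a sentinel when no .git folder is present (e.g. inside a Docker container). A condensed sketch; `git_short_hash` is an illustrative name, not the repo's:

    import subprocess

    def git_short_hash(fallback="0000000"):
        try:
            return subprocess.check_output(
                ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
        except subprocess.CalledProcessError:
            return fallback  # no .git folder available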
@@ -169,17 +169,43 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr


-def weight_decay(optimizer, wd):
+def weight_decay(optimizer):
     """
     Custom weight decay operation, not effecting grad values.
     """
     for group in optimizer.param_groups:
         for param in group['params']:
             current_lr = group['lr']
-            param.data = param.data.add(-wd * group['lr'], param.data)
+            weight_decay = group['weight_decay']
+            param.data = param.data.add(-weight_decay * group['lr'],
+                                        param.data)
     return optimizer, current_lr


+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+    """
+    Skip biases, BatchNorm parameters for weight decay
+    and attention projection layer v
+    """
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if len(param.shape) == 1 or name in skip_list:
+            print(name)
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [{
+        'params': no_decay,
+        'weight_decay': 0.
+    }, {
+        'params': decay,
+        'weight_decay': weight_decay
+    }]
+
+
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
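This is the core of the commit: applying L2-style decay to biases and BatchNorm scale/shift parameters tends to hurt, so every 1-D parameter (plus anything in skip_list) is routed to a zero-decay group. A usage sketch, assuming the set_weight_decay definition above is in scope:

    import torch

    net = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.BatchNorm1d(16))
    groups = set_weight_decay(net, weight_decay=1e-6)
    # (the helper also prints the name of each skipped parameter)

    # Linear bias + BatchNorm weight/bias are all 1-D -> zero-decay group;
    # only the 2-D Linear weight keeps the configured decay.
    print(len(groups[0]['params']), groups[0]['weight_decay'])  # 3 0.0
    print(len(groups[1]['params']), groups[1]['weight_decay'])  # 1 1e-06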
@@ -188,8 +214,8 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def get_lr(self):
         step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps**0.5 * min(
-                step * self.warmup_steps**-1.5, step**-0.5)
+            base_lr * self.warmup_steps**0.5 *
+            min(step * self.warmup_steps**-1.5, step**-0.5)
             for base_lr in self.base_lrs
         ]

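The reformatted expression is the standard Noam warmup schedule, lr(step) = base_lr * warmup_steps^0.5 * min(step * warmup_steps^-1.5, step^-0.5): it rises linearly, peaks at exactly base_lr when step == warmup_steps, then decays like step^-0.5. A standalone check of the arithmetic (pure Python, no torch):

    def noam_lr(base_lr, warmup_steps, step):
        step = max(step, 1)
        return base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5,
                                                 step**-0.5)

    print(noam_lr(1e-3, 4000, 1))      # tiny lr at the first step
    print(noam_lr(1e-3, 4000, 4000))   # peak: base_lr at step == warmup_steps
    print(noam_lr(1e-3, 4000, 16000))  # decays as step**-0.5 afterwards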
@@ -244,8 +270,8 @@ def set_init_dict(model_dict, checkpoint, c):
     }
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(
-        len(pretrained_dict), len(model_dict)))
+    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
+                                                     len(model_dict)))
     return model_dict

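For context, the reformatted print sits at the end of the usual partial-restore idiom: filter the checkpoint down to entries the current model can accept, then overwrite the model's state dict. A condensed sketch; the shape-matching condition is assumed for illustration:

    def partial_restore(model, checkpoint_state):
        model_dict = model.state_dict()
        # keep only checkpoint entries whose name and shape match the model
        pretrained_dict = {
            k: v for k, v in checkpoint_state.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(pretrained_dict)
        print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
                                                         len(model_dict)))
        model.load_state_dict(model_dict)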
@@ -254,37 +280,35 @@ def setup_model(num_chars, num_speakers, c):
     MyModel = importlib.import_module('TTS.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in "tacotron":
-        model = MyModel(
-            num_chars=num_chars,
-            num_speakers=num_speakers,
-            r=c.r,
-            linear_dim=1025,
-            mel_dim=80,
-            gst=c.use_gst,
-            memory_size=c.memory_size,
-            attn_win=c.windowing,
-            attn_norm=c.attention_norm,
-            prenet_type=c.prenet_type,
-            prenet_dropout=c.prenet_dropout,
-            forward_attn=c.use_forward_attn,
-            trans_agent=c.transition_agent,
-            forward_attn_mask=c.forward_attn_mask,
-            location_attn=c.location_attn,
-            separate_stopnet=c.separate_stopnet)
+        model = MyModel(num_chars=num_chars,
+                        num_speakers=num_speakers,
+                        r=c.r,
+                        linear_dim=1025,
+                        mel_dim=80,
+                        gst=c.use_gst,
+                        memory_size=c.memory_size,
+                        attn_win=c.windowing,
+                        attn_norm=c.attention_norm,
+                        prenet_type=c.prenet_type,
+                        prenet_dropout=c.prenet_dropout,
+                        forward_attn=c.use_forward_attn,
+                        trans_agent=c.transition_agent,
+                        forward_attn_mask=c.forward_attn_mask,
+                        location_attn=c.location_attn,
+                        separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
-        model = MyModel(
-            num_chars=num_chars,
-            num_speakers=num_speakers,
-            r=c.r,
-            attn_win=c.windowing,
-            attn_norm=c.attention_norm,
-            prenet_type=c.prenet_type,
-            prenet_dropout=c.prenet_dropout,
-            forward_attn=c.use_forward_attn,
-            trans_agent=c.transition_agent,
-            forward_attn_mask=c.forward_attn_mask,
-            location_attn=c.location_attn,
-            separate_stopnet=c.separate_stopnet)
+        model = MyModel(num_chars=num_chars,
+                        num_speakers=num_speakers,
+                        r=c.r,
+                        attn_win=c.windowing,
+                        attn_norm=c.attention_norm,
+                        prenet_type=c.prenet_type,
+                        prenet_dropout=c.prenet_dropout,
+                        forward_attn=c.use_forward_attn,
+                        trans_agent=c.transition_agent,
+                        forward_attn_mask=c.forward_attn_mask,
+                        location_attn=c.location_attn,
+                        separate_stopnet=c.separate_stopnet)
     return model

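This hunk is formatting only, but the surrounding code shows how setup_model resolves the model class at runtime from the config string. A minimal sketch of that importlib + getattr dispatch; the names in the comments are examples:

    import importlib

    def load_class(module_path, class_name):
        module = importlib.import_module(module_path)  # e.g. 'TTS.models.tacotron2'
        return getattr(module, class_name)             # e.g. the Tacotron2 class

    # In setup_model: MyModel = load_class('TTS.models.' + c.model.lower(), c.model)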
@@ -292,7 +316,8 @@ def split_dataset(items):
     is_multi_speaker = False
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(
+        len(items) * 0.01)
     np.random.seed(0)
     np.random.shuffle(items)
     if is_multi_speaker:
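The reformatted conditional is just a capped 1% eval split: 500 items or 1% of the dataset, whichever is smaller. An equivalent standalone form (function name is illustrative):

    def eval_split_size(n_items):
        return min(500, int(n_items * 0.01))

    print(eval_split_size(10000))   # 100 -> 1% of a small dataset
    print(eval_split_size(200000))  # 500 -> capped for large datasets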