mirror of https://github.com/coqui-ai/TTS.git
Weight decay described here: http://www.fast.ai/2018/07/02/adam-weight-decay/
parent 9b29b4e281
commit 16db5159f1

1 changed file: train.py (10 changed lines)
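The linked fast.ai post describes decoupled weight decay (the AdamW scheme): instead of folding an L2 penalty into the Adam-normalized gradient, each weight is shrunk by lr * wd * w as a separate step. A minimal sketch of that pattern for a generic PyTorch training step (hypothetical model, inputs, targets, and loss_fn; not this project's exact code):

import torch

def train_step(model, optimizer, loss_fn, inputs, targets, wd):
    """One update with weight decay decoupled from the gradient step."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # Decoupled decay: w <- w * (1 - lr * wd), applied directly to the
    # weights rather than through the optimizer's weight_decay argument.
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group['params']:
                param.mul_(1 - group['lr'] * wd)
    optimizer.step()
    return loss.item()

The commit applies exactly this pattern inside train() below, once per optimizer.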
@@ -89,6 +89,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm, skip_flag = check_update(model, 1)
         if skip_flag:
             optimizer.zero_grad()
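An aside on the added update (not part of the commit): param.data.add(-c.wd * group['lr'], param.data) computes param + (-wd * lr) * param, i.e. param * (1 - lr * wd), the decoupled shrink from the article. The positional-alpha Tensor.add(alpha, other) form is the legacy PyTorch signature; newer releases pass the scale as a keyword:

import torch

p = torch.ones(3)
lr, wd = 0.001, 0.01
decayed = p.add(p, alpha=-wd * lr)  # modern spelling of p.add(-wd * lr, p)
assert torch.allclose(decayed, p * (1 - wd * lr))

The same step is repeated for the stopnet optimizer in the next hunk.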
@@ -98,6 +101,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
+        for group in optimizer_st.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
         if skip_flag:
             optimizer_st.zero_grad()
@@ -390,9 +396,9 @@ def main(args):
     model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
-    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=c.wd)
+    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
     optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=c.wd)
+        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
 
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
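Setting weight_decay=0 matters here: it turns off Adam's built-in decay, which is the coupled L2 variant applied through the gradient, so the only decay left is the manual decoupled step added above. For reference, current PyTorch ships the decoupled scheme directly as optim.AdamW; a sketch with the same config values (assuming c.lr and c.wd as used in this file):

import torch.optim as optim

# AdamW applies w <- w - lr * wd * w internally, separate from the
# Adam-normalized gradient step, so no manual decay loop is needed.
optimizer = optim.AdamW(model.parameters(), lr=c.lr, weight_decay=c.wd)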