reformat vocoder training and disable G_adv eval if step < d_training

erogol 2020-06-11 12:42:13 +02:00
parent b8766c05ce
commit 9e71d4d66d
1 changed file with 25 additions and 22 deletions
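
The behavioral half of this commit is a gate on the discriminator: before `steps_to_start_discriminator`, the discriminator is never run and the generator loss receives None for every adversarial input. A minimal sketch of that pattern, assuming a standard PyTorch setup; the helper name and standalone signature are illustrative, not part of the script:

import torch

def adversarial_inputs(model_D, y_hat, y_G, global_step, c):
    """Illustrative helper (not in the script): compute the discriminator-side
    inputs for the generator loss, or return Nones while G_adv is disabled."""
    if global_step <= c.steps_to_start_discriminator:
        return None, None, None
    D_out_fake = model_D(y_hat)
    if isinstance(D_out_fake, tuple):      # some discriminators also return feature maps
        scores_fake, feats_fake = D_out_fake
    else:
        scores_fake, feats_fake = D_out_fake, None
    feats_real = None
    if c.use_feat_match_loss:
        with torch.no_grad():              # real features are targets, not trainable
            D_out_real = model_D(y_G)
        if isinstance(D_out_real, tuple):
            _, feats_real = D_out_real
    return scores_fake, feats_fake, feats_real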


@@ -112,7 +112,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,

         ##############################
         # generator pass
-        optimizer_G.zero_grad()
         y_hat = model_G(c_G)
         y_hat_sub = None
         y_G_sub = None
@@ -121,8 +120,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         # PQMF formatting
         if y_hat.shape[1] > 1:
             y_hat_sub = y_hat
-            y_hat = model_G.pqmf_synthesis(y_hat)
             y_hat_vis = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
             y_G_sub = model_G.pqmf_analysis(y_G)

         if global_step > c.steps_to_start_discriminator:
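
For context on the PQMF branch above: a multi-band generator emits several subband signals that are folded into one full-band waveform, while the target is decomposed the other way for subband losses. A sketch under that assumption (the wrapper is illustrative; `pqmf_synthesis` and `pqmf_analysis` are the script's own methods):

def pqmf_format(model_G, y_hat, y_G):
    """Illustrative wrapper around the branch above: keep the multi-band output
    for subband losses, then fold it into a single full-band waveform."""
    y_hat_sub, y_G_sub = None, None
    if y_hat.shape[1] > 1:                     # generator emitted N subbands: [B, N, T/N]
        y_hat_sub = y_hat                      # kept for subband STFT losses
        y_hat = model_G.pqmf_synthesis(y_hat)  # folded to full band: [B, 1, T]
        y_G_sub = model_G.pqmf_analysis(y_G)   # matching subband targets
    return y_hat, y_hat_sub, y_G_sub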
@@ -142,12 +141,11 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             if isinstance(D_out_fake, tuple):
                 scores_fake, feats_fake = D_out_fake
                 if D_out_real is None:
-                    scores_real, feats_real = None, None
+                    feats_real = None
                 else:
-                    scores_real, feats_real = D_out_real
+                    _, feats_real = D_out_real
             else:
                 scores_fake = D_out_fake
-                scores_real = D_out_real
         else:
             scores_fake, feats_fake, feats_real = None, None, None
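
The hunk above also drops `scores_real`, which the generator loss never uses; only the real feature maps matter, for feature matching. A sketch of the resulting normalization (the helper name is illustrative):

def format_D_outputs(D_out_fake, D_out_real):
    """Illustrative helper: discriminators return either a raw score tensor or
    a (scores, features) tuple; only fake scores and the feature sets are kept."""
    feats_fake, feats_real = None, None
    if isinstance(D_out_fake, tuple):
        scores_fake, feats_fake = D_out_fake
        if D_out_real is not None:
            _, feats_real = D_out_real     # real scores are discarded
    else:
        scores_fake = D_out_fake
    return scores_fake, feats_fake, feats_real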
@@ -157,6 +155,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         loss_G = loss_G_dict['G_loss']

         # optimizer generator
+        optimizer_G.zero_grad()
         loss_G.backward()
         if c.gen_clip_grad > 0:
             torch.nn.utils.clip_grad_norm_(model_G.parameters(),
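
Both `zero_grad()` moves (here and in the discriminator hunk below) converge on the standard PyTorch update sequence, shown as a sketch with generic names (`optimizer_update` and `clip_value` are illustrative):

import torch

def optimizer_update(loss, model, optimizer, clip_value):
    """Standard update order: zero stale grads right before backward, then
    optionally clip the gradient norm, then step."""
    optimizer.zero_grad()
    loss.backward()
    if clip_value > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
    optimizer.step()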
@@ -184,8 +183,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             if y_hat.shape[1] > 1:
                 y_hat = model_G.pqmf_synthesis(y_hat)

-            optimizer_D.zero_grad()
-
             # run D with or without cond. features
             if len(signature(model_D.forward).parameters) == 2:
                 D_out_fake = model_D(y_hat.detach(), c_D)
@@ -210,6 +207,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             loss_D = loss_D_dict['D_loss']

             # optimizer discriminator
+            optimizer_D.zero_grad()
             loss_D.backward()
             if c.disc_clip_grad > 0:
                 torch.nn.utils.clip_grad_norm_(model_D.parameters(),
@@ -326,25 +324,30 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
             y_hat = model_G.pqmf_synthesis(y_hat)
             y_G_sub = model_G.pqmf_analysis(y_G)

-        if len(signature(model_D.forward).parameters) == 2:
-            D_out_fake = model_D(y_hat, c_G)
-        else:
-            D_out_fake = model_D(y_hat)
-        D_out_real = None
-        if c.use_feat_match_loss:
-            with torch.no_grad():
-                D_out_real = model_D(y_G)
-
-        # format D outputs
-        if isinstance(D_out_fake, tuple):
-            scores_fake, feats_fake = D_out_fake
-            if D_out_real is None:
-                feats_real = None
-            else:
-                _, feats_real = D_out_real
-        else:
-            scores_fake = D_out_fake
+        if global_step > c.steps_to_start_discriminator:
+            if len(signature(model_D.forward).parameters) == 2:
+                D_out_fake = model_D(y_hat, c_G)
+            else:
+                D_out_fake = model_D(y_hat)
+            D_out_real = None
+            if c.use_feat_match_loss:
+                with torch.no_grad():
+                    D_out_real = model_D(y_G)
+
+            # format D outputs
+            if isinstance(D_out_fake, tuple):
+                scores_fake, feats_fake = D_out_fake
+                if D_out_real is None:
+                    feats_real = None
+                else:
+                    _, feats_real = D_out_real
+            else:
+                scores_fake = D_out_fake
+        else:
+            scores_fake, feats_fake, feats_real = None, None, None

         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
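
For the None path added above to be safe, `criterion_G` has to skip its adversarial and feature-matching terms when the discriminator outputs are missing. A minimal sketch of such a criterion, assuming a single score tensor and lists of feature maps (the class is illustrative, not the repo's implementation):

import torch
import torch.nn.functional as F

class GeneratorLossSketch(torch.nn.Module):
    """Illustrative criterion: adversarial and feature-matching terms are only
    added once real discriminator outputs exist."""

    def forward(self, y_hat, y_G, scores_fake=None, feats_fake=None, feats_real=None):
        loss = F.l1_loss(y_hat, y_G)           # placeholder reconstruction term
        if scores_fake is not None:
            loss = loss + F.relu(1.0 - scores_fake).mean()   # hinge-style G_adv term
        if feats_fake is not None and feats_real is not None:
            loss = loss + sum(F.l1_loss(f, r) for f, r in zip(feats_fake, feats_real))
        return {'G_loss': loss}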