mirror of https://github.com/coqui-ai/TTS.git
reformatting vocoder training and disable G_adv eval if step < d_training
This commit is contained in:
parent
b8766c05ce
commit
9e71d4d66d
|
@ -112,7 +112,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
##############################
|
||||
|
||||
# generator pass
|
||||
optimizer_G.zero_grad()
|
||||
y_hat = model_G(c_G)
|
||||
y_hat_sub = None
|
||||
y_G_sub = None
|
||||
|
@ -121,8 +120,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat_sub = y_hat
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_hat_vis = y_hat
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
if global_step > c.steps_to_start_discriminator:
|
||||
|
@ -142,12 +141,11 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
scores_real, feats_real = None, None
|
||||
feats_real = None
|
||||
else:
|
||||
scores_real, feats_real = D_out_real
|
||||
_, feats_real = D_out_real
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
scores_real = D_out_real
|
||||
else:
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
|
||||
|
@ -157,6 +155,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
loss_G = loss_G_dict['G_loss']
|
||||
|
||||
# optimizer generator
|
||||
optimizer_G.zero_grad()
|
||||
loss_G.backward()
|
||||
if c.gen_clip_grad > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
|
||||
|
@ -184,8 +183,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
if y_hat.shape[1] > 1:
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
|
||||
optimizer_D.zero_grad()
|
||||
|
||||
# run D with or without cond. features
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat.detach(), c_D)
|
||||
|
@ -210,6 +207,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
loss_D = loss_D_dict['D_loss']
|
||||
|
||||
# optimizer discriminator
|
||||
optimizer_D.zero_grad()
|
||||
loss_D.backward()
|
||||
if c.disc_clip_grad > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
|
||||
|
@ -326,25 +324,30 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat, c_G)
|
||||
else:
|
||||
D_out_fake = model_D(y_hat)
|
||||
D_out_real = None
|
||||
|
||||
if c.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = model_D(y_G)
|
||||
if global_step > c.steps_to_start_discriminator:
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
feats_real = None
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat, c_G)
|
||||
else:
|
||||
_, feats_real = D_out_real
|
||||
D_out_fake = model_D(y_hat)
|
||||
D_out_real = None
|
||||
|
||||
if c.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = model_D(y_G)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
feats_real = None
|
||||
else:
|
||||
_, feats_real = D_out_real
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
|
||||
# compute losses
|
||||
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
|
||||
|
|
Loading…
Reference in New Issue