From 9e71d4d66d5126a13f881f42f6e80504ca0ee560 Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 11 Jun 2020 12:42:13 +0200 Subject: [PATCH] reformatting vocoder training and disable G_adv eval if step < d_training --- vocoder/train.py | 47 +++++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/vocoder/train.py b/vocoder/train.py index ce8f111a..74b3e759 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -112,7 +112,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, ############################## # generator pass - optimizer_G.zero_grad() y_hat = model_G(c_G) y_hat_sub = None y_G_sub = None @@ -121,8 +120,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, # PQMF formatting if y_hat.shape[1] > 1: y_hat_sub = y_hat - y_hat = model_G.pqmf_synthesis(y_hat) y_hat_vis = y_hat + y_hat = model_G.pqmf_synthesis(y_hat) y_G_sub = model_G.pqmf_analysis(y_G) if global_step > c.steps_to_start_discriminator: @@ -142,12 +141,11 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, if isinstance(D_out_fake, tuple): scores_fake, feats_fake = D_out_fake if D_out_real is None: - scores_real, feats_real = None, None + feats_real = None else: - scores_real, feats_real = D_out_real + _, feats_real = D_out_real else: scores_fake = D_out_fake - scores_real = D_out_real else: scores_fake, feats_fake, feats_real = None, None, None @@ -157,6 +155,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, loss_G = loss_G_dict['G_loss'] # optimizer generator + optimizer_G.zero_grad() loss_G.backward() if c.gen_clip_grad > 0: torch.nn.utils.clip_grad_norm_(model_G.parameters(), @@ -184,8 +183,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, if y_hat.shape[1] > 1: y_hat = model_G.pqmf_synthesis(y_hat) - optimizer_D.zero_grad() - # run D with or without cond. features if len(signature(model_D.forward).parameters) == 2: D_out_fake = model_D(y_hat.detach(), c_D) @@ -210,6 +207,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, loss_D = loss_D_dict['D_loss'] # optimizer discriminator + optimizer_D.zero_grad() loss_D.backward() if c.disc_clip_grad > 0: torch.nn.utils.clip_grad_norm_(model_D.parameters(), @@ -326,25 +324,30 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) y_hat = model_G.pqmf_synthesis(y_hat) y_G_sub = model_G.pqmf_analysis(y_G) - if len(signature(model_D.forward).parameters) == 2: - D_out_fake = model_D(y_hat, c_G) - else: - D_out_fake = model_D(y_hat) - D_out_real = None - if c.use_feat_match_loss: - with torch.no_grad(): - D_out_real = model_D(y_G) + if global_step > c.steps_to_start_discriminator: - # format D outputs - if isinstance(D_out_fake, tuple): - scores_fake, feats_fake = D_out_fake - if D_out_real is None: - feats_real = None + if len(signature(model_D.forward).parameters) == 2: + D_out_fake = model_D(y_hat, c_G) else: - _, feats_real = D_out_real + D_out_fake = model_D(y_hat) + D_out_real = None + + if c.use_feat_match_loss: + with torch.no_grad(): + D_out_real = model_D(y_G) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + feats_real = None + else: + _, feats_real = D_out_real + else: + scores_fake = D_out_fake else: - scores_fake = D_out_fake + scores_fake, feats_fake, feats_real = None, None, None # compute losses loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,