mirror of https://github.com/coqui-ai/TTS.git
Fix GAN optimizer order
commit 212d330929
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Fri Apr 29 16:29:44 2022 -0300

    Fix unit test

commit 44456b0483
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Fri Apr 29 07:28:39 2022 -0300

    Fix style

commit d545beadb9
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Thu Apr 28 17:08:04 2022 -0300

    Change order of HiFi-GAN optimizers to match the original repository

commit 657c5442e5
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Thu Apr 28 15:40:16 2022 -0300

    Remove audio padding before mel spec extraction

commit 76b274e690
Merge: 379ccd7b 6233f4fc
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Wed Apr 27 07:28:48 2022 -0300

    Merge pull request #1541 from coqui-ai/comp_emb_fix

    Bug fix in compute embedding without eval partition

commit 379ccd7ba6
Author: WeberJulian <julian.weber@hotmail.fr>
Date:   Wed Apr 27 10:42:26 2022 +0200

    returns y_mask in VITS inference (#1540)

    * returns y_mask
    * make style
This commit is contained in:
parent 60034674f9
commit a0a9279e4b
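In short: this change reorders GAN.train_step in TTS/vocoder/models/gan.py so that optimizer_idx == 0 now updates the discriminator and optimizer_idx == 1 the generator, matching the training order of the original HiFi-GAN repository. Because the generator forward pass now runs inside the discriminator step, its outputs are cached on the model for reuse, and the optimizer, learning-rate, scheduler, and criterion lists are all swapped so that index 0 consistently refers to the discriminator.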
@@ -90,50 +90,26 @@ class GAN(BaseVocoder):
             raise ValueError(" [!] Unexpected `optimizer_idx`.")
 
         if optimizer_idx == 0:
-            # GENERATOR
+            # DISCRIMINATOR optimization
+
             # generator pass
             y_hat = self.model_g(x)[:, :, : y.size(2)]
-            self.y_hat_g = y_hat  # save for discriminator
-            y_hat_sub = None
-            y_sub = None
+
+            # cache for generator loss
+            # pylint: disable=W0201
+            self.y_hat_g = y_hat
+            self.y_hat_sub = None
+            self.y_sub_g = None
 
             # PQMF formatting
             if y_hat.shape[1] > 1:
-                y_hat_sub = y_hat
+                self.y_hat_sub = y_hat
                 y_hat = self.model_g.pqmf_synthesis(y_hat)
-                self.y_hat_g = y_hat  # save for discriminator
-                y_sub = self.model_g.pqmf_analysis(y)
+                self.y_hat_g = y_hat  # save for generator loss
+                self.y_sub_g = self.model_g.pqmf_analysis(y)
 
             scores_fake, feats_fake, feats_real = None, None, None
-            if self.train_disc:
-
-                if len(signature(self.model_d.forward).parameters) == 2:
-                    D_out_fake = self.model_d(y_hat, x)
-                else:
-                    D_out_fake = self.model_d(y_hat)
-                D_out_real = None
-
-                if self.config.use_feat_match_loss:
-                    with torch.no_grad():
-                        D_out_real = self.model_d(y)
-
-                # format D outputs
-                if isinstance(D_out_fake, tuple):
-                    scores_fake, feats_fake = D_out_fake
-                    if D_out_real is None:
-                        feats_real = None
-                    else:
-                        _, feats_real = D_out_real
-                else:
-                    scores_fake = D_out_fake
-                    feats_fake, feats_real = None, None
-
-            # compute losses
-            loss_dict = criterion[optimizer_idx](y_hat, y, scores_fake, feats_fake, feats_real, y_hat_sub, y_sub)
-            outputs = {"model_outputs": y_hat}
 
-        if optimizer_idx == 1:
-            # DISCRIMINATOR
             if self.train_disc:
                 # use different samples for G and D trainings
                 if self.config.diff_samples_for_G_and_D:
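Note the `# pylint: disable=W0201` in the new code: y_hat_g, y_hat_sub, and y_sub_g become instance attributes assigned outside __init__, which is what lets the generator step below reuse the forward pass already computed here instead of running the generator twice per batch.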
@@ -177,6 +153,36 @@ class GAN(BaseVocoder):
             loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
             outputs = {"model_outputs": y_hat}
 
+        if optimizer_idx == 1:
+            # GENERATOR loss
+            scores_fake, feats_fake, feats_real = None, None, None
+            if self.train_disc:
+                if len(signature(self.model_d.forward).parameters) == 2:
+                    D_out_fake = self.model_d(self.y_hat_g, x)
+                else:
+                    D_out_fake = self.model_d(self.y_hat_g)
+                D_out_real = None
+
+                if self.config.use_feat_match_loss:
+                    with torch.no_grad():
+                        D_out_real = self.model_d(y)
+
+                # format D outputs
+                if isinstance(D_out_fake, tuple):
+                    scores_fake, feats_fake = D_out_fake
+                    if D_out_real is None:
+                        feats_real = None
+                    else:
+                        _, feats_real = D_out_real
+                else:
+                    scores_fake = D_out_fake
+                    feats_fake, feats_real = None, None
+
+            # compute losses
+            loss_dict = criterion[optimizer_idx](
+                self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
+            )
+            outputs = {"model_outputs": self.y_hat_g}
         return outputs, loss_dict
 
     @staticmethod
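The generator branch above reads the tensors cached by the discriminator branch, so the two steps are meant to run in index order on the same batch. A minimal sketch of that calling pattern (the gan/batch names and the explicit calls are illustrative, not this repo's Trainer code):

    # One training iteration after this commit (illustrative sketch)
    criteria = gan.get_criterion()  # [DiscriminatorLoss, GeneratorLoss]

    # step 0: runs the generator forward pass, caches self.y_hat_g, trains D
    outputs_d, loss_d = gan.train_step(batch, criteria, optimizer_idx=0)

    # step 1: reuses the cached self.y_hat_g to compute the generator loss
    outputs_g, loss_g = gan.train_step(batch, criteria, optimizer_idx=1)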
@@ -210,6 +216,7 @@ class GAN(BaseVocoder):
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Call `train_step()` with `no_grad()`"""
+        self.train_disc = True  # Avoid a bug in the Training with the missing discriminator loss
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(
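Forcing self.train_disc = True in eval_step makes evaluation always take the `if self.train_disc:` branches, so the eval loss dict contains the discriminator-dependent terms even while the discriminator is still disabled during early training steps; the inline comment flags this as a workaround.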
@@ -266,7 +273,7 @@ class GAN(BaseVocoder):
         optimizer2 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
         )
-        return [optimizer1, optimizer2]
+        return [optimizer2, optimizer1]
 
     def get_lr(self) -> List:
         """Set the initial learning rates for each optimizer.
@@ -274,7 +281,7 @@ class GAN(BaseVocoder):
         Returns:
             List: learning rates for each optimizer.
         """
-        return [self.config.lr_gen, self.config.lr_disc]
+        return [self.config.lr_disc, self.config.lr_gen]
 
     def get_scheduler(self, optimizer) -> List:
         """Set the schedulers for each optimizer.
@@ -287,7 +294,7 @@ class GAN(BaseVocoder):
         """
         scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
         scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
-        return [scheduler1, scheduler2]
+        return [scheduler2, scheduler1]
 
     @staticmethod
     def format_batch(batch: List) -> Dict:
@@ -359,7 +366,7 @@ class GAN(BaseVocoder):
 
     def get_criterion(self):
         """Return criterions for the optimizers"""
-        return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
+        return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
 
     @staticmethod
     def init_from_config(config: Coqpit, verbose=True) -> "GAN":
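Taken together, the last four hunks preserve the per-index pairing: after the swap, index 0 means discriminator and index 1 means generator in the optimizer, learning-rate, scheduler, and criterion lists alike. A sketch of the contract a trainer relies on (the loop and variable names are illustrative, and the "loss" key is an assumption about the loss dicts, not shown in this diff):

    # The same index must select the same network in every list.
    # After this commit: index 0 -> discriminator, index 1 -> generator.
    optimizers = gan.get_optimizer()            # [optimizer_d, optimizer_g]
    schedulers = gan.get_scheduler(optimizers)  # same order as the optimizers
    criteria = gan.get_criterion()              # [DiscriminatorLoss, GeneratorLoss]

    for idx, optimizer in enumerate(optimizers):
        optimizer.zero_grad()
        _, loss_dict = gan.train_step(batch, criteria, idx)
        loss_dict["loss"].backward()            # assumption: total loss keyed as "loss"
        optimizer.step()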