mirror of https://github.com/coqui-ai/TTS.git
Fix unit test
This commit is contained in:
parent
44456b0483
commit
212d330929
|
@ -155,6 +155,7 @@ class GAN(BaseVocoder):
|
||||||
|
|
||||||
if optimizer_idx == 1:
|
if optimizer_idx == 1:
|
||||||
# GENERATOR loss
|
# GENERATOR loss
|
||||||
|
scores_fake, feats_fake, feats_real = None, None, None
|
||||||
if self.train_disc:
|
if self.train_disc:
|
||||||
if len(signature(self.model_d.forward).parameters) == 2:
|
if len(signature(self.model_d.forward).parameters) == 2:
|
||||||
D_out_fake = self.model_d(self.y_hat_g, x)
|
D_out_fake = self.model_d(self.y_hat_g, x)
|
||||||
|
@ -182,7 +183,6 @@ class GAN(BaseVocoder):
|
||||||
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
|
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
|
||||||
)
|
)
|
||||||
outputs = {"model_outputs": self.y_hat_g}
|
outputs = {"model_outputs": self.y_hat_g}
|
||||||
|
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -216,6 +216,7 @@ class GAN(BaseVocoder):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||||
"""Call `train_step()` with `no_grad()`"""
|
"""Call `train_step()` with `no_grad()`"""
|
||||||
|
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
|
||||||
return self.train_step(batch, criterion, optimizer_idx)
|
return self.train_step(batch, criterion, optimizer_idx)
|
||||||
|
|
||||||
def eval_log(
|
def eval_log(
|
||||||
|
|
Loading…
Reference in New Issue