Update GAN model

Eren Gölge 2022-01-25 09:22:35 +00:00
parent 2829027d8b
commit 13482dde1f
1 changed file with 12 additions and 12 deletions

TTS/vocoder/models/gan.py

@@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
 
 class GAN(BaseVocoder):
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
         """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
         It also helps mixing and matching different generator and discriminator networks easily.
@@ -28,6 +28,7 @@ class GAN(BaseVocoder):
         Args:
             config (Coqpit): Model configuration.
+            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
 
         Examples:
             Initializing the GAN model with HifiGAN generator and discriminator.
@@ -41,6 +42,7 @@ class GAN(BaseVocoder):
         self.model_d = setup_discriminator(config)
         self.train_disc = False  # if False, train only the generator.
         self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator
+        self.ap = ap
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Run the generator's forward pass.
@@ -201,10 +203,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for training."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
@@ -215,10 +216,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for evaluation."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     def load_checkpoint(
         self,
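Both logging callbacks now take the processor from the model rather than from the `assets` dict, so the trainer no longer needs to place an "audio_processor" entry there. A hedged sketch of the new call pattern (`batch`, `outputs`, and `logger` stand in for objects produced during a training run):

    # assets may be empty now; self.ap supplies the figures and the sample rate.
    model.eval_log(batch, outputs, logger, assets={}, steps=1000)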
@@ -330,12 +330,11 @@ class GAN(BaseVocoder):
         Returns:
             DataLoader: Torch dataloader.
         """
-        ap = assets["audio_processor"]
         dataset = GANDataset(
-            ap=ap,
+            ap=self.ap,
             items=data_items,
             seq_len=config.seq_len,
-            hop_len=ap.hop_length,
+            hop_len=self.ap.hop_length,
             pad_short=config.pad_short,
             conv_pad=config.conv_pad,
             return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
@@ -363,5 +362,6 @@ class GAN(BaseVocoder):
         return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
 
     @staticmethod
-    def init_from_config(config: Coqpit) -> "GAN":
-        return GAN(config)
+    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
+        ap = AudioProcessor.init_from_config(config, verbose=verbose)
+        return GAN(config, ap=ap)
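With the updated factory, the processor is wired in automatically. A short sketch, again using HifiganConfig defaults only for illustration:

    from TTS.vocoder.configs import HifiganConfig
    from TTS.vocoder.models.gan import GAN

    # init_from_config now builds the AudioProcessor from the config internally.
    model = GAN.init_from_config(HifiganConfig(), verbose=False)
    assert model.ap is not None  # the processor travels with the model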