Update GAN model

Eren Gölge 2022-01-25 09:22:35 +00:00
parent 2829027d8b
commit 13482dde1f
1 changed file with 12 additions and 12 deletions
TTS/vocoder/models

@@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
 
 
 class GAN(BaseVocoder):
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
         """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
         It also helps mixing and matching different generator and disciminator networks easily.
@@ -28,6 +28,7 @@ class GAN(BaseVocoder):
 
         Args:
             config (Coqpit): Model configuration.
+            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
 
         Examples:
             Initializing the GAN model with HifiGAN generator and discriminator.
@@ -41,6 +42,7 @@ class GAN(BaseVocoder):
         self.model_d = setup_discriminator(config)
         self.train_disc = False  # if False, train only the generator.
         self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator
+        self.ap = ap
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Run the generator's forward pass.
@@ -201,10 +203,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for training."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
@@ -215,10 +216,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for evaluation."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
 
     def load_checkpoint(
         self,
@@ -330,12 +330,11 @@ class GAN(BaseVocoder):
         Returns:
             DataLoader: Torch dataloader.
         """
-        ap = assets["audio_processor"]
         dataset = GANDataset(
-            ap=ap,
+            ap=self.ap,
             items=data_items,
             seq_len=config.seq_len,
-            hop_len=ap.hop_length,
+            hop_len=self.ap.hop_length,
             pad_short=config.pad_short,
             conv_pad=config.conv_pad,
             return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
@@ -363,5 +362,6 @@ class GAN(BaseVocoder):
         return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
 
     @staticmethod
-    def init_from_config(config: Coqpit) -> "GAN":
-        return GAN(config)
+    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
+        ap = AudioProcessor.init_from_config(config, verbose=verbose)
+        return GAN(config, ap=ap)