mirror of https://github.com/coqui-ai/TTS.git
Update GAN model
This commit is contained in:
parent 2829027d8b
commit 13482dde1f
@@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results


 class GAN(BaseVocoder):
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
         """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
         It also helps mixing and matching different generator and discriminator networks easily.
@@ -28,6 +28,7 @@ class GAN(BaseVocoder):

         Args:
             config (Coqpit): Model configuration.
+            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.

         Examples:
             Initializing the GAN model with HifiGAN generator and discriminator.
@@ -41,6 +42,7 @@ class GAN(BaseVocoder):
         self.model_d = setup_discriminator(config)
         self.train_disc = False  # if False, train only the generator.
         self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator
+        self.ap = ap

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Run the generator's forward pass.
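With this change the vocoder model carries its own AudioProcessor (self.ap) instead of reading one from the trainer's assets dict. A minimal initialization sketch, assuming the HifiganConfig and AudioProcessor classes shipped with 🐸TTS; the mel tensor shape below is illustrative only:

    import torch
    from TTS.utils.audio import AudioProcessor
    from TTS.vocoder.configs import HifiganConfig
    from TTS.vocoder.models.gan import GAN

    config = HifiganConfig()
    ap = AudioProcessor.init_from_config(config)  # build the audio frontend from the same config
    model = GAN(config, ap=ap)                    # generator + discriminator behind one trainer-facing interface

    # forward() only runs the generator; feed a dummy mel batch of shape (batch, num_mels, frames)
    mel = torch.randn(1, config.audio.num_mels, 50)
    with torch.no_grad():
        wav = model(mel)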
@@ -201,10 +203,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for training."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
@@ -215,10 +216,9 @@ class GAN(BaseVocoder):
         self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for evaluation."""
-        ap = assets["audio_processor"]
-        figures, audios = self._log("eval", ap, batch, outputs)
+        figures, audios = self._log("eval", self.ap, batch, outputs)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
         self,
@@ -330,12 +330,11 @@ class GAN(BaseVocoder):
         Returns:
             DataLoader: Torch dataloader.
         """
-        ap = assets["audio_processor"]
         dataset = GANDataset(
-            ap=ap,
+            ap=self.ap,
             items=data_items,
             seq_len=config.seq_len,
-            hop_len=ap.hop_length,
+            hop_len=self.ap.hop_length,
             pad_short=config.pad_short,
             conv_pad=config.conv_pad,
             return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
@@ -363,5 +362,6 @@ class GAN(BaseVocoder):
         return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]

     @staticmethod
-    def init_from_config(config: Coqpit) -> "GAN":
-        return GAN(config)
+    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
+        ap = AudioProcessor.init_from_config(config, verbose=verbose)
+        return GAN(config, ap=ap)
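After this commit init_from_config constructs the AudioProcessor itself, so callers only pass the Coqpit config. A small usage sketch of the new entry point (HifiganConfig is used here only as an example config, it is not part of this diff):

    from TTS.vocoder.configs import HifiganConfig
    from TTS.vocoder.models.gan import GAN

    config = HifiganConfig()
    model = GAN.init_from_config(config, verbose=False)  # AudioProcessor is now created internally
    assert model.ap is not None                          # the model carries its own audio frontend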