mirror of https://github.com/coqui-ai/TTS.git
Update GAN model
This commit is contained in:
parent
2829027d8b
commit
13482dde1f
|
@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
|
||||||
|
|
||||||
class GAN(BaseVocoder):
|
class GAN(BaseVocoder):
|
||||||
def __init__(self, config: Coqpit):
|
def __init__(self, config: Coqpit, ap: AudioProcessor=None):
|
||||||
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
|
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
|
||||||
It also helps mixing and matching different generator and discriminator networks easily.
|
It also helps mixing and matching different generator and discriminator networks easily.
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ class GAN(BaseVocoder):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Coqpit): Model configuration.
|
config (Coqpit): Model configuration.
|
||||||
|
ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
Initializing the GAN model with HifiGAN generator and discriminator.
|
Initializing the GAN model with HifiGAN generator and discriminator.
|
||||||
|
@ -41,6 +42,7 @@ class GAN(BaseVocoder):
|
||||||
self.model_d = setup_discriminator(config)
|
self.model_d = setup_discriminator(config)
|
||||||
self.train_disc = False # if False, train only the generator.
|
self.train_disc = False # if False, train only the generator.
|
||||||
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
|
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Run the generator's forward pass.
|
"""Run the generator's forward pass.
|
||||||
|
@ -201,10 +203,9 @@ class GAN(BaseVocoder):
|
||||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||||
) -> Tuple[Dict, np.ndarray]:
|
) -> Tuple[Dict, np.ndarray]:
|
||||||
"""Call `_log()` for training."""
|
"""Call `_log()` for training."""
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||||
figures, audios = self._log("eval", ap, batch, outputs)
|
|
||||||
logger.eval_figures(steps, figures)
|
logger.eval_figures(steps, figures)
|
||||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||||
|
@ -215,10 +216,9 @@ class GAN(BaseVocoder):
|
||||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||||
) -> Tuple[Dict, np.ndarray]:
|
) -> Tuple[Dict, np.ndarray]:
|
||||||
"""Call `_log()` for evaluation."""
|
"""Call `_log()` for evaluation."""
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||||
figures, audios = self._log("eval", ap, batch, outputs)
|
|
||||||
logger.eval_figures(steps, figures)
|
logger.eval_figures(steps, figures)
|
||||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
|
@ -330,12 +330,11 @@ class GAN(BaseVocoder):
|
||||||
Returns:
|
Returns:
|
||||||
DataLoader: Torch dataloader.
|
DataLoader: Torch dataloader.
|
||||||
"""
|
"""
|
||||||
ap = assets["audio_processor"]
|
|
||||||
dataset = GANDataset(
|
dataset = GANDataset(
|
||||||
ap=ap,
|
ap=self.ap,
|
||||||
items=data_items,
|
items=data_items,
|
||||||
seq_len=config.seq_len,
|
seq_len=config.seq_len,
|
||||||
hop_len=ap.hop_length,
|
hop_len=self.ap.hop_length,
|
||||||
pad_short=config.pad_short,
|
pad_short=config.pad_short,
|
||||||
conv_pad=config.conv_pad,
|
conv_pad=config.conv_pad,
|
||||||
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
|
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
|
||||||
|
@ -363,5 +362,6 @@ class GAN(BaseVocoder):
|
||||||
return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
|
return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: Coqpit) -> "GAN":
|
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
|
||||||
return GAN(config)
|
ap = AudioProcessor.init_from_config(config, verbose=verbose)
|
||||||
|
return GAN(config, ap=ap)
|
||||||
|
|
Loading…
Reference in New Issue