Update GAN for Trainer_v2

This commit is contained in:
Eren Gölge 2021-09-30 14:20:30 +00:00
parent a156a40b47
commit 4baecdf92a
4 changed files with 91 additions and 19 deletions

View File

@ -35,7 +35,7 @@ class GAN(BaseVocoder):
>>> config = HifiganConfig()
>>> model = GAN(config)
"""
super().__init__()
super().__init__(config)
self.config = config
self.model_g = setup_generator(config)
self.model_d = setup_discriminator(config)
@ -197,18 +197,24 @@ class GAN(BaseVocoder):
audios = {f"{name}/audio": sample_voice}
return figures, audios
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def train_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
return self._log("train", ap, batch, outputs)
ap = assets["audio_processor"]
self._log("train", ap, batch, outputs)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def eval_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
return self._log("eval", ap, batch, outputs)
ap = assets["audio_processor"]
self._log("eval", ap, batch, outputs)
def load_checkpoint(
self,
@ -299,7 +305,7 @@ class GAN(BaseVocoder):
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
ap: AudioProcessor,
assets: Dict,
is_eval: True,
data_items: List,
verbose: bool,
@ -318,6 +324,7 @@ class GAN(BaseVocoder):
Returns:
DataLoader: Torch dataloader.
"""
ap = assets["audio_processor"]
dataset = GANDataset(
ap=ap,
items=data_items,

View File

@ -1,29 +1,51 @@
import os
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN
output_path = os.path.dirname(os.path.abspath(__file__))
config = HifiganConfig(
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
test_delay_epochs=5,
epochs=1000,
seq_len=8192,
pad_short=2000,
use_noise_augment=True,
eval_split_size=10,
print_step=25,
print_eval=True,
print_eval=False,
mixed_precision=False,
lr_gen=1e-4,
lr_disc=1e-4,
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
# init model
model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@ -1,29 +1,51 @@
import os
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import MultibandMelganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN
output_path = os.path.dirname(os.path.abspath(__file__))
config = MultibandMelganConfig(
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
test_delay_epochs=5,
epochs=1000,
seq_len=8192,
pad_short=2000,
use_noise_augment=True,
eval_split_size=10,
print_step=25,
print_eval=True,
print_eval=False,
mixed_precision=False,
lr_gen=1e-4,
lr_disc=1e-4,
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
# init model
model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@ -1,7 +1,10 @@
import os
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import UnivnetConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN
output_path = os.path.dirname(os.path.abspath(__file__))
config = UnivnetConfig(
@ -24,6 +27,24 @@ config = UnivnetConfig(
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
# init model
model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()