mirror of https://github.com/coqui-ai/TTS.git
Update GAN for Trainer_v2
This commit is contained in:
parent a156a40b47
commit 4baecdf92a

@@ -35,7 +35,7 @@ class GAN(BaseVocoder):
             >>> config = HifiganConfig()
             >>> model = GAN(config)
         """
-        super().__init__()
+        super().__init__(config)
         self.config = config
         self.model_g = setup_generator(config)
         self.model_d = setup_discriminator(config)
@@ -197,18 +197,24 @@ class GAN(BaseVocoder):
         audios = {f"{name}/audio": sample_voice}
         return figures, audios

-    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def train_log(
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for training."""
-        return self._log("train", ap, batch, outputs)
+        ap = assets["audio_processor"]
+        self._log("train", ap, batch, outputs)

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Call `train_step()` with `no_grad()`"""
         return self.train_step(batch, criterion, optimizer_idx)

-    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def eval_log(
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         """Call `_log()` for evaluation."""
-        return self._log("eval", ap, batch, outputs)
+        ap = assets["audio_processor"]
+        self._log("eval", ap, batch, outputs)

     def load_checkpoint(
         self,
@@ -299,7 +305,7 @@ class GAN(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: True,
         data_items: List,
         verbose: bool,
@@ -318,6 +324,7 @@ class GAN(BaseVocoder):
         Returns:
             DataLoader: Torch dataloader.
         """
+        ap = assets["audio_processor"]
         dataset = GANDataset(
             ap=ap,
             items=data_items,
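
The model-side hunks above are the heart of the migration: Trainer_v2 no longer passes an AudioProcessor into `train_log`/`eval_log` directly, it forwards a shared `assets` dict that the model indexes by key. Below is a minimal, self-contained sketch of that assumed contract; `MyVocoder`, its `_log` stub, and the `"<ap>"` placeholder are hypothetical illustrations, not the actual TTS classes.

    from typing import Dict

    class MyVocoder:
        """Hypothetical stand-in for a BaseVocoder subclass."""

        def _log(self, name: str, ap, batch: Dict, outputs: Dict) -> None:
            # A real model renders figures/audio here; printing just shows
            # where the processor ends up.
            print(f"[{name}] logging with audio processor: {ap!r}")

        def train_log(self, batch: Dict, outputs: Dict, logger, assets: Dict, steps: int) -> None:
            ap = assets["audio_processor"]  # same key the recipes register
            self._log("train", ap, batch, outputs)

    # The trainer forwards whatever was given as `training_assets`:
    MyVocoder().train_log(batch={}, outputs={}, logger=None, assets={"audio_processor": "<ap>"}, steps=0)
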
@@ -1,29 +1,51 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import HifiganConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
+
 config = HifiganConfig(
     batch_size=32,
     eval_batch_size=16,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
-    test_delay_epochs=-1,
+    test_delay_epochs=5,
     epochs=1000,
     seq_len=8192,
     pad_short=2000,
     use_noise_augment=True,
     eval_split_size=10,
     print_step=25,
-    print_eval=True,
+    print_eval=False,
     mixed_precision=False,
     lr_gen=1e-4,
     lr_disc=1e-4,
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
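
Each recipe now unpacks `eval_samples, train_samples` from `load_wav_data(config.data_path, config.eval_split_size)`. The sketch below is an illustrative reimplementation of what that helper is assumed to do (gather wav paths and carve off an eval split); it is not the library function itself, which may also shuffle before splitting.

    import glob
    import os
    from typing import List, Tuple

    def split_wav_items(data_path: str, eval_split_size: int) -> Tuple[List[str], List[str]]:
        """Illustrative stand-in for TTS.vocoder.datasets.preprocess.load_wav_data."""
        wavs = sorted(glob.glob(os.path.join(data_path, "*.wav")))
        # (eval, train) order matches the recipes' unpacking.
        return wavs[:eval_split_size], wavs[eval_split_size:]

    eval_samples, train_samples = split_wav_items("../LJSpeech-1.1/wavs/", 10)
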
@@ -1,29 +1,51 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import MultibandMelganConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
+
 config = MultibandMelganConfig(
     batch_size=32,
     eval_batch_size=16,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
-    test_delay_epochs=-1,
+    test_delay_epochs=5,
     epochs=1000,
     seq_len=8192,
     pad_short=2000,
     use_noise_augment=True,
     eval_split_size=10,
     print_step=25,
-    print_eval=True,
+    print_eval=False,
     mixed_precision=False,
     lr_gen=1e-4,
     lr_disc=1e-4,
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()

@@ -1,7 +1,10 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import UnivnetConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.gan import GAN

 output_path = os.path.dirname(os.path.abspath(__file__))
 config = UnivnetConfig(
@@ -24,6 +27,24 @@ config = UnivnetConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = GAN(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
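
All three recipes build the processor with `AudioProcessor(**config.audio.to_dict())` and register it under the same key the model code reads back. The following self-contained sketch shows that kwargs-unpacking pattern and the key contract, with hypothetical stub classes standing in for the TTS ones.

    from dataclasses import dataclass, asdict
    from typing import Dict

    @dataclass
    class StubAudioConfig:
        """Hypothetical stand-in for `config.audio`."""
        sample_rate: int = 22050
        num_mels: int = 80

        def to_dict(self) -> Dict:
            return asdict(self)

    class StubProcessor:
        """Hypothetical stand-in for AudioProcessor."""
        def __init__(self, sample_rate: int, num_mels: int):
            self.sample_rate = sample_rate
            self.num_mels = num_mels

    ap = StubProcessor(**StubAudioConfig().to_dict())  # mirrors AudioProcessor(**config.audio.to_dict())
    training_assets = {"audio_processor": ap}          # key must match the model's lookup
    assert training_assets["audio_processor"].sample_rate == 22050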