diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 9b0d6837..1977efb6 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -222,10 +222,7 @@ class Wavernn(BaseVocoder):
             samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
             orthogonal method for increasing sampling efficiency.
         """
-        super().__init__()
-
-        self.args = config.model_params
-        self.config = config
+        super().__init__(config)
 
         if isinstance(self.args.mode, int):
             self.n_classes = 2 ** self.args.mode
@@ -572,8 +569,9 @@ class Wavernn(BaseVocoder):
 
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
+        ap = assets["audio_processor"]
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
@@ -600,20 +598,21 @@ class Wavernn(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: True,
         data_items: List,
         verbose: bool,
         num_gpus: int,
     ):
+        ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
             items=data_items,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
-            pad=config.model_params.pad,
-            mode=config.model_params.mode,
-            mulaw=config.model_params.mulaw,
+            pad=config.model_args.pad,
+            mode=config.model_args.mode,
+            mulaw=config.model_args.mulaw,
             is_training=not is_eval,
             verbose=verbose,
         )
diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py
index 8f138298..9777a985 100644
--- a/recipes/ljspeech/wavernn/train_wavernn.py
+++ b/recipes/ljspeech/wavernn/train_wavernn.py
@@ -1,7 +1,11 @@
 import os
 
-from TTS.trainer import Trainer, TrainingArgs, init_training
+
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavernnConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.wavernn import Wavernn
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavernnConfig(
@@ -24,6 +28,24 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavernn(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
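
Note: the sketch below is a minimal, self-contained illustration of the "training_assets" pattern this patch adopts, not Coqui's actual Trainer internals. The trainer holds one dict of shared objects and forwards it to every model hook, so hooks such as get_data_loader() and test_run() take an assets dict instead of a positional AudioProcessor. ToyTrainer, ToyModel, and FakeAudioProcessor are hypothetical names used only for this example.

# Hypothetical sketch of the asset-passing pattern (assumed, not Coqui's code).
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeAudioProcessor:  # stand-in for TTS.utils.audio.AudioProcessor
    hop_length: int = 256


class ToyModel:
    def get_data_loader(self, config: Any, assets: Dict, is_eval: bool,
                        data_items: List, verbose: bool, num_gpus: int):
        ap = assets["audio_processor"]  # same lookup the patched Wavernn does
        print(f"{'eval' if is_eval else 'train'} loader, hop={ap.hop_length}")
        return list(data_items)

    def test_run(self, assets: Dict, samples: List[Dict], output: Dict):
        ap = assets["audio_processor"]  # hooks pull only what they need
        return {"figures": {}}, {"audios": {}}


class ToyTrainer:
    def __init__(self, model, training_assets: Dict):
        self.model = model
        # One dict shared by all hooks: registering a new asset later does
        # not change any model method signature.
        self.training_assets = training_assets

    def fit(self):
        # Every hook receives the same assets dict.
        self.model.get_data_loader(None, self.training_assets, is_eval=False,
                                   data_items=[], verbose=True, num_gpus=1)
        self.model.test_run(self.training_assets, samples=[], output={})


ToyTrainer(ToyModel(), {"audio_processor": FakeAudioProcessor()}).fit()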