Update WaveRNN

Eren Gölge 2021-09-30 14:21:35 +00:00
parent 3d5205d66f
commit 4f94f91305
2 changed files with 33 additions and 12 deletions

View File

@@ -222,10 +222,7 @@ class Wavernn(BaseVocoder):
         samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
         orthogonal method for increasing sampling efficiency.
         """
-        super().__init__()
-        self.args = config.model_params
-        self.config = config
+        super().__init__(config)

         if isinstance(self.args.mode, int):
             self.n_classes = 2 ** self.args.mode
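The constructor change works because the base class takes over config handling: super().__init__(config) replaces the manual self.config/self.args assignments, so the subclass can keep reading self.args as before. A minimal sketch of what BaseVocoder is assumed to provide (its real implementation is not part of this diff; the model_args attribute name follows the rename visible further down):

from torch import nn

class BaseVocoder(nn.Module):
    # Hypothetical sketch: the real BaseVocoder in TTS is not shown in this
    # diff. Assumption: it stores the config and exposes the model arguments
    # that Wavernn.__init__ previously extracted itself.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.args = config.model_args  # formerly config.model_params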
@@ -572,8 +569,9 @@ class Wavernn(BaseVocoder):
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
+        ap = assets["audio_processor"]
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
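test_run now receives the shared assets dict instead of a bare AudioProcessor and unpacks the processor itself. A hedged call-site sketch, assuming a constructed model, an AudioProcessor instance named ap, and a list of test samples (these names are illustrative, not from the diff):

# Hypothetical call site; in training the Trainer passes its training_assets.
assets = {"audio_processor": ap}
figures, audios = model.test_run(assets, samples, None)  # output arg is unused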
@@ -600,20 +598,21 @@ class Wavernn(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: True,
         data_items: List,
         verbose: bool,
         num_gpus: int,
     ):
+        ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
             items=data_items,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
-            pad=config.model_params.pad,
-            mode=config.model_params.mode,
-            mulaw=config.model_params.mulaw,
+            pad=config.model_args.pad,
+            mode=config.model_args.mode,
+            mulaw=config.model_args.mulaw,
             is_training=not is_eval,
             verbose=verbose,
         )
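get_data_loader follows the same pattern, resolving the AudioProcessor from assets, and the dataset options move from config.model_params to the renamed config.model_args. A usage sketch under the same assumptions, with train_samples as produced by load_wav_data in the recipe below:

# Hypothetical invocation, mirroring how the Trainer is expected to call it.
loader = model.get_data_loader(
    config=config,
    assets={"audio_processor": ap},
    is_eval=False,             # build the training loader
    data_items=train_samples,  # list of wav paths
    verbose=True,
    num_gpus=1,
)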

View File

@@ -1,7 +1,11 @@
 import os
-from TTS.trainer import Trainer, TrainingArgs, init_training
+
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavernnConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.wavernn import Wavernn

 output_path = os.path.dirname(os.path.abspath(__file__))

 config = WavernnConfig(
@@ -24,6 +28,24 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavernn(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
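The recipe drops the removed init_training helper in favor of the new Trainer API: the script builds the audio processor, data splits, and model explicitly, and the training_assets dict it hands to Trainer is the same assets mapping the model hooks above unpack. A comment-only sketch of the assumed flow (the routing is inferred, not shown in this diff):

# Hypothetical flow inside trainer.fit(), tying the two files together:
#   Trainer(..., training_assets={"audio_processor": ap})
#     -> model.get_data_loader(config, assets=training_assets, ...)
#     -> model.test_run(training_assets, samples, output)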