Update WaveRNN

Eren Gölge 2021-09-30 14:21:35 +00:00
parent 3d5205d66f
commit 4f94f91305
2 changed files with 33 additions and 12 deletions

TTS/vocoder/models/wavernn.py

@@ -222,10 +222,7 @@ class Wavernn(BaseVocoder):
         samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
         orthogonal method for increasing sampling efficiency.
         """
-        super().__init__()
-        self.args = config.model_params
-        self.config = config
+        super().__init__(config)
         if isinstance(self.args.mode, int):
             self.n_classes = 2 ** self.args.mode
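The three lines of manual bookkeeping collapse into a single `super().__init__(config)` call, so the base class now has to store the config itself. A minimal sketch of what `BaseVocoder.__init__` is assumed to do (its internals are not part of this diff; the attribute names are inferred from the `self.args.mode` access that immediately follows):

    # Hypothetical sketch, not part of the commit: the base class must set both
    # `self.config` and `self.args`, since Wavernn still reads `self.args.mode`
    # right after calling super().__init__(config).
    from coqpit import Coqpit
    from torch import nn

    class BaseVocoder(nn.Module):
        def __init__(self, config: Coqpit):
            super().__init__()
            self.config = config
            # `model_params` is renamed to `model_args` elsewhere in this commit
            self.args = config.model_args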
@@ -572,8 +569,9 @@ class Wavernn(BaseVocoder):
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
+        ap = assets["audio_processor"]
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
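`test_run` now receives a generic `assets` dict and unpacks the `AudioProcessor` itself instead of taking it as a dedicated argument. A hedged usage sketch under the new signature (illustrative only; whether raw items from `load_wav_data` are valid `samples` here depends on code outside this diff):

    from TTS.utils.audio import AudioProcessor
    from TTS.vocoder.configs import WavernnConfig
    from TTS.vocoder.datasets.preprocess import load_wav_data
    from TTS.vocoder.models.wavernn import Wavernn

    config = WavernnConfig()
    ap = AudioProcessor(**config.audio.to_dict())
    model = Wavernn(config)
    eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)

    # The processor now travels inside the assets dict; `output` stays unused
    # (note the pylint disable on the signature).
    figures, audios = model.test_run(
        assets={"audio_processor": ap},
        samples=eval_samples,
        output=None,
    )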
@@ -600,20 +598,21 @@ class Wavernn(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: bool,
         data_items: List,
         verbose: bool,
         num_gpus: int,
     ):
+        ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
             items=data_items,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
-            pad=config.model_params.pad,
-            mode=config.model_params.mode,
-            mulaw=config.model_params.mulaw,
+            pad=config.model_args.pad,
+            mode=config.model_args.mode,
+            mulaw=config.model_args.mulaw,
             is_training=not is_eval,
             verbose=verbose,
         )
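The same indirection applies here: the loader builder takes the `assets` dict and pulls the processor out on its first line. Continuing the sketch above, a hypothetical trainer-side call under the new signature (the real caller lives in `TTS.trainer`, outside this diff):

    # Hypothetical call site, illustrative only.
    loader = model.get_data_loader(
        config=config,
        assets={"audio_processor": ap},
        is_eval=False,
        data_items=train_samples,
        verbose=True,
        num_gpus=1,
    )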

recipes/ljspeech/wavernn/train_wavernn.py

@@ -1,7 +1,11 @@
 import os
-from TTS.trainer import Trainer, TrainingArgs, init_training
+
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavernnConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.wavernn import Wavernn
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavernnConfig(
@@ -24,6 +28,24 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavernn(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
+trainer.fit()
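The recipe now wires everything up explicitly instead of going through `init_training`: the script builds the `AudioProcessor`, the sample lists, and the model, then hands them to the `Trainer`. The `training_assets` dict passed here is the same object that later surfaces as the `assets` parameter of `test_run` and `get_data_loader` in the model file above. Launching is presumably unchanged, e.g. `CUDA_VISIBLE_DEVICES="0" python train_wavernn.py` from the recipe directory.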