mirror of https://github.com/coqui-ai/TTS.git
Update WaveRNN
This commit is contained in:
parent 3d5205d66f
commit 4f94f91305
@@ -222,10 +222,7 @@ class Wavernn(BaseVocoder):
         samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
         orthogonal method for increasing sampling efficiency.
         """
-        super().__init__()
-
-        self.args = config.model_params
-        self.config = config
+        super().__init__(config)
 
         if isinstance(self.args.mode, int):
             self.n_classes = 2 ** self.args.mode
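The constructor change above pushes the `self.config` and `self.args` bookkeeping down into the base class: `Wavernn` now hands `config` to `super().__init__()` yet still reads `self.args.mode` afterwards, so the base class must set both attributes. A minimal sketch of that base-class side, assuming `BaseVocoder` is an `nn.Module` that stores the config and exposes its model arguments (hypothetical; the actual implementation in TTS may differ):

from coqpit import Coqpit
from torch import nn


class BaseVocoder(nn.Module):
    # Hypothetical sketch, not the actual TTS implementation.
    def __init__(self, config: Coqpit):
        super().__init__()
        # Store the config and expose its model arguments so subclasses
        # like Wavernn can keep using self.args.mode without assigning
        # self.config / self.args themselves.
        self.config = config
        if hasattr(config, "model_args"):
            self.args = config.model_args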
@@ -572,8 +569,9 @@ class Wavernn(BaseVocoder):
 
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
+        ap = assets["audio_processor"]
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
@@ -600,20 +598,21 @@ class Wavernn(BaseVocoder):
     def get_data_loader(  # pylint: disable=no-self-use
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: True,
         data_items: List,
         verbose: bool,
         num_gpus: int,
     ):
+        ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
             items=data_items,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
-            pad=config.model_params.pad,
-            mode=config.model_params.mode,
-            mulaw=config.model_params.mulaw,
+            pad=config.model_args.pad,
+            mode=config.model_args.mode,
+            mulaw=config.model_args.mulaw,
             is_training=not is_eval,
             verbose=verbose,
         )
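Both `test_run` and `get_data_loader` now receive a generic `assets` dict in place of a dedicated `AudioProcessor` argument and unpack the processor under the `"audio_processor"` key, matching the `training_assets` dict the recipe below hands to the trainer; the hunk also renames `config.model_params` to `config.model_args`. A hypothetical direct call under that convention (the data path is a placeholder):

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.wavernn import Wavernn

config = WavernnConfig(data_path="/data/LJSpeech-1.1/wavs/")  # placeholder path
ap = AudioProcessor(**config.audio.to_dict())
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)

# The same assets dict serves every model hook.
assets = {"audio_processor": ap}
model = Wavernn(config)
loader = model.get_data_loader(config, assets, False, train_samples, True, num_gpus=0)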
@@ -1,7 +1,11 @@
 import os
 
-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavernnConfig
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.models.wavernn import Wavernn
 
+
 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavernnConfig(
@@ -24,6 +28,24 @@ config = WavernnConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavernn(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
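The recipe no longer goes through `init_training`; it builds the audio processor, samples, and model itself and passes them to `Trainer` explicitly, with `training_assets` carrying the `AudioProcessor` into the model hooks shown above. A rough, hypothetical sketch of that hand-off on the trainer side (heavily simplified; the real `TTS.trainer.Trainer` differs):

class TrainerSketch:
    """Hypothetical stand-in for TTS.trainer.Trainer, for illustration only."""

    def __init__(self, args, config, output_path, model=None,
                 train_samples=None, eval_samples=None, training_assets=None):
        self.config = config
        self.model = model
        self.train_samples = train_samples
        self.eval_samples = eval_samples
        # Forwarded verbatim to model hooks such as get_data_loader and test_run.
        self.training_assets = training_assets or {}

    def fit(self):
        loader = self.model.get_data_loader(
            self.config, self.training_assets, False, self.train_samples, True, 1
        )
        for batch in loader:
            ...  # training step omitted in this sketch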