Fix WaveRNN config and test

This commit is contained in:
Eren Gölge 2021-09-30 16:20:12 +00:00
parent 55d9209221
commit 7edbe04fe0
2 changed files with 7 additions and 7 deletions

View File

@ -17,11 +17,11 @@ class BaseVocoderConfig(BaseTrainingConfig):
Number of instances used for evaluation. Defaults to 10. Number of instances used for evaluation. Defaults to 10.
data_path (str): data_path (str):
Root path of the training data. All the audio files found recursively from this root path are used for Root path of the training data. All the audio files found recursively from this root path are used for
training. Defaults to MISSING. training. Defaults to `""`.
feature_path (str): feature_path (str):
Root path to the precomputed feature files. Defaults to None. Root path to the precomputed feature files. Defaults to None.
seq_len (int): seq_len (int):
Length of the waveform segments used for training. Defaults to MISSING. Length of the waveform segments used for training. Defaults to 1000.
pad_short (int): pad_short (int):
Extra padding for the waveforms shorter than `seq_len`. Defaults to 0. Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
conv_pad (int): conv_pad (int):
@ -45,9 +45,9 @@ class BaseVocoderConfig(BaseTrainingConfig):
use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms. use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms.
eval_split_size: int = 10 # number of samples used for evaluation. eval_split_size: int = 10 # number of samples used for evaluation.
# dataset # dataset
data_path: str = MISSING # root data path. It finds all wav files recursively from there. data_path: str = "" # root data path. It finds all wav files recursively from there.
feature_path: str = None # if you use precomputed features feature_path: str = None # if you use precomputed features
seq_len: int = MISSING # signal length used in training. seq_len: int = 1000 # signal length used in training.
pad_short: int = 0 # additional padding for short wavs pad_short: int = 0 # additional padding for short wavs
conv_pad: int = 0 # additional padding against convolutions applied to spectrograms conv_pad: int = 0 # additional padding against convolutions applied to spectrograms
use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM. use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM.

View File

@ -12,7 +12,7 @@ def test_wavernn():
config.model_args = WavernnArgs( config.model_args = WavernnArgs(
rnn_dims=512, rnn_dims=512,
fc_dims=512, fc_dims=512,
mode=10, mode="mold",
mulaw=False, mulaw=False,
pad=2, pad=2,
use_aux_net=True, use_aux_net=True,
@ -37,13 +37,13 @@ def test_wavernn():
assert np.all(output.shape == (2, 1280, 30)), output.shape assert np.all(output.shape == (2, 1280, 30)), output.shape
# mode: gauss # mode: gauss
config.model_params.mode = "gauss" config.model_args.mode = "gauss"
model = Wavernn(config) model = Wavernn(config)
output = model(dummy_x, dummy_m) output = model(dummy_x, dummy_m)
assert np.all(output.shape == (2, 1280, 2)), output.shape assert np.all(output.shape == (2, 1280, 2)), output.shape
# mode: quantized # mode: quantized
config.model_params.mode = 4 config.model_args.mode = 4
model = Wavernn(config) model = Wavernn(config)
output = model(dummy_x, dummy_m) output = model(dummy_x, dummy_m)
assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape