mirror of https://github.com/coqui-ai/TTS.git
Fix WaveRNN config and test
parent 55d9209221
commit 7edbe04fe0
@@ -17,11 +17,11 @@ class BaseVocoderConfig(BaseTrainingConfig):
             Number of instances used for evaluation. Defaults to 10.
         data_path (str):
             Root path of the training data. All the audio files found recursively from this root path are used for
-            training. Defaults to MISSING.
+            training. Defaults to `""`.
         feature_path (str):
             Root path to the precomputed feature files. Defaults to None.
         seq_len (int):
-            Length of the waveform segments used for training. Defaults to MISSING.
+            Length of the waveform segments used for training. Defaults to 1000.
         pad_short (int):
             Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
         conv_path (int):
@@ -45,9 +45,9 @@ class BaseVocoderConfig(BaseTrainingConfig):
     use_noise_augment: bool = False  # enable/disable random noise augmentation in spectrograms.
     eval_split_size: int = 10  # number of samples used for evaluation.
     # dataset
-    data_path: str = MISSING  # root data path. It finds all wav files recursively from there.
+    data_path: str = ""  # root data path. It finds all wav files recursively from there.
     feature_path: str = None  # if you use precomputed features
-    seq_len: int = MISSING  # signal length used in training.
+    seq_len: int = 1000  # signal length used in training.
     pad_short: int = 0  # additional padding for short wavs
     conv_pad: int = 0  # additional padding against convolutions applied to spectrograms
     use_cache: bool = False  # use in memory cache to keep the computed features. This might cause OOM.
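Note (not part of the diff): a minimal sketch of what the new defaults mean in practice, assuming the config lives at the usual module path in this repo. With data_path and seq_len no longer set to the MISSING sentinel, a BaseVocoderConfig can be built without supplying those fields up front; the values below simply mirror the defaults introduced by this commit.

# Hypothetical usage sketch; the import path is an assumption, not part of this commit.
from TTS.vocoder.configs.shared_configs import BaseVocoderConfig

config = BaseVocoderConfig()
assert config.data_path == ""   # was MISSING before this commit
assert config.seq_len == 1000   # was MISSING before this commit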
@@ -12,7 +12,7 @@ def test_wavernn():
     config.model_args = WavernnArgs(
         rnn_dims=512,
         fc_dims=512,
-        mode=10,
+        mode="mold",
         mulaw=False,
         pad=2,
         use_aux_net=True,
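Note (not part of the diff): a hedged reading of the mode fix above. An integer mode selects an n-bit quantized output head (the test below asserts 2 ** 4 channels for mode=4), so mode=10 did not match the first assert's expectation of 30 channels; the string "mold" selects the mixture-of-logistics head instead. The import path and keyword-only construction in this sketch are assumptions.

# Sketch of the three `mode` values exercised by the test; import path is assumed.
from TTS.vocoder.models.wavernn import WavernnArgs

mold_args = WavernnArgs(mode="mold")    # mixture-of-logistics output head
gauss_args = WavernnArgs(mode="gauss")  # single-Gaussian output head
quant_args = WavernnArgs(mode=4)        # 4-bit quantized output (2 ** 4 classes)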
@@ -37,13 +37,13 @@ def test_wavernn():
     assert np.all(output.shape == (2, 1280, 30)), output.shape

     # mode: gauss
-    config.model_params.mode = "gauss"
+    config.model_args.mode = "gauss"
     model = Wavernn(config)
     output = model(dummy_x, dummy_m)
     assert np.all(output.shape == (2, 1280, 2)), output.shape

     # mode: quantized
-    config.model_params.mode = 4
+    config.model_args.mode = 4
     model = Wavernn(config)
     output = model(dummy_x, dummy_m)
     assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape
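Note (not part of the diff): the asserts above pin down how the last output dimension follows from mode. The mapping below is a standalone sketch inferred from those asserts, not code taken from the library.

# Output-channel count per WaveRNN mode, inferred from the asserts in this test.
def expected_output_channels(mode):
    if mode == "mold":           # mixture of logistics -> 30 channels
        return 30
    if mode == "gauss":          # Gaussian head -> 2 channels
        return 2
    if isinstance(mode, int):    # n-bit quantized -> 2 ** n classes
        return 2 ** mode
    raise ValueError(f"Unknown WaveRNN mode: {mode}")

assert expected_output_channels("mold") == 30
assert expected_output_channels("gauss") == 2
assert expected_output_channels(4) == 2 ** 4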