add Coqpits for the vocoder models

This commit is contained in:
Eren Gölge 2021-05-07 15:33:52 +02:00
parent 6f4eed94f5
commit 3dec62b183
9 changed files with 532 additions and 0 deletions

View File

@ -0,0 +1,17 @@
import importlib
import os
from inspect import isclass
# import all files under configs/
configs_dir = os.path.dirname(__file__)
for file in os.listdir(configs_dir):
path = os.path.join(configs_dir, file)
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
config_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("TTS.vocoder.configs." + config_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isclass(attribute):
# Add the class to this package's variables
globals()[attribute_name] = attribute

View File

@ -0,0 +1,54 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class FullbandMelganConfig(BaseGANVocoderConfig):
"""Defines parameters for FullbandMelGAN vocoder."""
model: str = "melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {
"base_channels": 16,
"max_channels": 512,
"downsample_factors": [4, 4, 4]
})
generator_model: str = "melgan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors": [8, 8, 2, 2],
"num_res_blocks": 4
})
# Training - overrides
batch_size: int = 16
seq_len: int = 8192
pad_short: int = 2000
use_noise_augment: bool = True
use_cache: bool = True
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
})
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0

View File

@ -0,0 +1,53 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class HifiganConfig(BaseGANVocoderConfig):
"""Defines parameters for HifiGAN vocoder."""
model: str = "hifigan"
# model specific params
discriminator_model: str = "hifigan_discriminator"
generator_model: str = "hifigan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_type": "1"
})
# LOSS PARAMETERS - overrides
use_stft_loss: bool = False
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = True
# loss weights - overrides
stft_loss_weight: float = 0
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 1
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 45
l1_spec_loss_params: dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None
})
# optimizer parameters
lr: float = 1e-4
wd: float = 1e-6

View File

@ -0,0 +1,54 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class MelganConfig(BaseGANVocoderConfig):
"""Defines parameters for MelGAN vocoder."""
model: str = "melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {
"base_channels": 16,
"max_channels": 1024,
"downsample_factors": [4, 4, 4, 4]
})
generator_model: str = "melgan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors": [8, 8, 2, 2],
"num_res_blocks": 3
})
# Training - overrides
batch_size: int = 16
seq_len: int = 8192
pad_short: int = 2000
use_noise_augment: bool = True
use_cache: bool = True
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
})
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0

View File

@ -0,0 +1,79 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class MultibandMelganConfig(BaseGANVocoderConfig):
"""Defines parameters for MultiBandMelGAN vocoder."""
model: str = "multiband_melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {
"base_channels": 16,
"max_channels": 512,
"downsample_factors": [4, 4, 4]
})
generator_model: str = "multiband_melgan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors": [8, 4, 2],
"num_res_blocks": 4
})
use_pqmf: bool = True
# optimizer - overrides
lr_gen: float = 0.0001 # Initial learning rate.
lr_disc: float = 0.0001 # Initial learning rate.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {
"betas": [0.8, 0.99],
"weight_decay": 0.0
})
lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {
"gamma": 0.5,
"milestones": [100000, 200000, 300000, 400000, 500000, 600000]
})
lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {
"gamma": 0.5,
"milestones": [100000, 200000, 300000, 400000, 500000, 600000]
})
# Training - overrides
batch_size: int = 64
seq_len: int = 16384
pad_short: int = 2000
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: bool = 200000
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = True
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
subband_stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [384, 683, 171],
"hop_lengths": [30, 60, 10],
"win_lengths": [150, 300, 60]
})
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0
# optimizer parameters
lr: float = 1e-4
wd: float = 1e-6

View File

@ -0,0 +1,73 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class ParallelWaveganConfig(BaseGANVocoderConfig):
"""Defines parameters for ParallelWavegan vocoder."""
model: str = "parallel_wavegan"
# Model specific params
discriminator_model: str = "parallel_wavegan_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {
"num_layers": 10
})
generator_model: str = "parallel_wavegan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors":[4, 4, 4, 4],
"stacks": 3,
"num_res_blocks": 30
})
# Training - overrides
batch_size: int = 6
seq_len: int = 25600
pad_short: int = 2000
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: int = 200000
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
})
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
# optimizer overrides
lr_gen: float = 0.0002 # Initial learning rate.
lr_disc: float = 0.0002 # Initial learning rate.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {
"betas": [0.8, 0.99],
"weight_decay": 0.0
})
lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {
"gamma": 0.999,
"last_epoch": -1
})
lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {
"gamma": 0.999,
"last_epoch": -1
})

View File

@ -0,0 +1,92 @@
from dataclasses import dataclass, field
from typing import List
from coqpit import MISSING
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
@dataclass
class BaseVocoderConfig(BaseTrainingConfig):
"""Shared parameters among all the vocoder models."""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# dataloading
use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms.
eval_split_size: int = 10 # number of samples used for evaluation.
# dataset
data_path: str = MISSING # root data path. It finds all wav files recursively from there.
feature_path: str = None # if you use precomputed features
seq_len: int = MISSING # signal length used in training.
pad_short: int = 0 # additional padding for short wavs
conv_pad: int = 0 # additional padding against convolutions applied to spectrograms
use_noise_augment: bool = False # add noise to the audio signal for augmentation
use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM.
# OPTIMIZER
epochs: int = 10000 # total number of epochs to train.
wd: float = 0.0 # Weight decay weight.
@dataclass
class BaseGANVocoderConfig(BaseVocoderConfig):
"""Common config interface for all the GAN based vocoder models."""
# LOSS PARAMETERS
use_stft_loss: bool = True
use_subband_stft_loss: bool = True
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = True
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = True
# loss weights
stft_loss_weight: float = 0
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 1
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 10
l1_spec_loss_weight: float = 45
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
})
l1_spec_loss_params: dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None
})
target_loss: str = "avg_G_loss" # loss value to pick the best model to save after each epoch
# optimizer
gen_clip_grad: float = -1 # Generator gradient clipping threshold. Apply gradient clipping if > 0
disc_clip_grad: float = -1 # Discriminator gradient clipping threshold.
lr_gen: float = 0.0002 # Initial learning rate.
lr_disc: float = 0.0002 # Initial learning rate.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {
"betas": [0.8, 0.99],
"weight_decay": 0.0
})
lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {
"gamma": 0.999,
"last_epoch": -1
})
lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {
"gamma": 0.999,
"last_epoch": -1
})
use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
steps_to_start_discriminator = 0 # start training the discriminator after this number of steps.
diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps.

View File

@ -0,0 +1,58 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseVocoderConfig
@dataclass
class WavegradConfig(BaseVocoderConfig):
"""Defines parameters for Wavernn vocoder."""
model: str = 'wavegrad'
# Model specific params
generator_model: str = "wavegrad"
model_params: dict = field(
default_factory=lambda: {
"use_weight_norm":
True,
"y_conv_channels":
32,
"x_conv_channels":
768,
"ublock_out_channels": [512, 512, 256, 128, 128],
"dblock_out_channels": [128, 128, 256, 512],
"upsample_factors": [4, 4, 4, 2, 2],
"upsample_dilations": [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
[1, 2, 4, 8], [1, 2, 4, 8]]
})
target_loss: str = 'avg_wavegrad_loss' # loss value to pick the best model to save after each epoch
# Training - overrides
epochs: int = 10000
batch_size: int = 96
seq_len: int = 6144
use_cache: bool = True
steps_to_start_discriminator: int = 200000
mixed_precision: bool = True
eval_split_size: int = 50
# NOISE SCHEDULE PARAMS
train_noise_schedule: dict = field(default_factory=lambda: {
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 1000
})
test_noise_schedule: dict = field(default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values.
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 50
})
# optimizer overrides
grad_clip: float = 1.0
lr: float = 1e-4 # Initial learning rate.
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_params: dict = field(
default_factory=lambda: {
"gamma": 0.5,
"milestones": [100000, 200000, 300000, 400000, 500000, 600000]
})

View File

@ -0,0 +1,52 @@
from dataclasses import asdict, dataclass, field
from .shared_configs import BaseVocoderConfig
@dataclass
class WavernnConfig(BaseVocoderConfig):
"""Defines parameters for Wavernn vocoder."""
model: str = "wavernn"
# Model specific params
mode: str = 'mold' # mold [string], gauss [string], bits [int]
mulaw: bool = True # apply mulaw if mode is bits
generator_model: str = "WaveRNN"
wavernn_model_params: dict = field(
default_factory=lambda: {
"rnn_dims": 512,
"fc_dims": 512,
"compute_dims": 128,
"res_out_dims": 128,
"num_res_blocks": 10,
"use_aux_net": True,
"use_upsample_net": True,
"upsample_factors":
[4, 8, 8] # this needs to correctly factorise hop_length
})
# Inference
batched: bool = True
target_samples: int = 11000
overlap_samples: int = 550
# Training - overrides
epochs: int = 10000
batch_size: int = 256
seq_len: int = 1280
padding: int = 2
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: int = 200000
mixed_precision: bool = True
eval_split_size: int = 50
test_every_epochs: int = 10 # number of epochs to wait until the next test run (synthesizing a full audio clip).
# optimizer overrides
grad_clip: float = 4.0
lr: float = 1e-4 # Initial learning rate.
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_params: dict = field(default_factory=lambda: {
"gamma": 0.5,
"milestones": [200000, 400000, 600000]
})