Implement VitsAudioConfig (#1556)

* Implement VitsAudioConfig

* Update VITS LJSpeech recipe

* Update VITS VCTK recipe

* Make style

* Add missing decorator

* Add missing param

* Make style

* Update recipes

* Fix test

* Bug fix

* Exclude tests folder

* Make linter

* Make style
This commit is contained in:
Eren Gölge 2022-07-12 18:49:58 +02:00 committed by GitHub
parent 34b80e0280
commit 49bac724c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 65 additions and 76 deletions

View File

@ -11,4 +11,5 @@ recursive-include TTS *.md
recursive-include TTS *.py recursive-include TTS *.py
recursive-include TTS *.pyx recursive-include TTS *.pyx
recursive-include images *.png recursive-include images *.png
recursive-exclude tests *
prune tests*

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass, field
from typing import List from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
@dataclass @dataclass
@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
model_args (VitsArgs): model_args (VitsArgs):
Model architecture arguments. Defaults to `VitsArgs()`. Model architecture arguments. Defaults to `VitsArgs()`.
audio (VitsAudioConfig):
Audio processing configuration. Defaults to `VitsAudioConfig()`.
grad_clip (List): grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
model: str = "vits" model: str = "vits"
# model specific params # model specific params
model_args: VitsArgs = field(default_factory=VitsArgs) model_args: VitsArgs = field(default_factory=VitsArgs)
audio: VitsAudioConfig = VitsAudioConfig()
# optimizer # optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])

View File

@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
return spec return spec
#############################
# CONFIGS
#############################
@dataclass
class VitsAudioConfig(Coqpit):
    """Audio processing configuration for the VITS model.

    A small, VITS-specific set of audio parameters. Every field has a
    default, so ``VitsAudioConfig()`` can be constructed with no arguments
    and individual values overridden by keyword.
    """

    # FFT size; presumably passed as `n_fft` to the spectrogram helpers
    # in this module (e.g. `wav_to_mel`) - TODO confirm against callers.
    fft_size: int = 1024
    # Sampling rate of the audio, presumably in Hz.
    sample_rate: int = 22050
    # Analysis window length; presumably in samples - TODO confirm.
    win_length: int = 1024
    # Hop between successive frames; presumably in samples - TODO confirm.
    hop_length: int = 256
    # Number of mel bands.
    num_mels: int = 80
    # Lower frequency bound of the mel filterbank (presumably Hz).
    mel_fmin: int = 0
    # Upper frequency bound of the mel filterbank (presumably Hz).
    # NOTE(review): annotated `int` but defaults to None; looks like None
    # means "no explicit upper bound" - confirm how the consumer treats it.
    mel_fmax: int = None
############################## ##############################
# DATASET # DATASET
############################## ##############################

View File

@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
""" """
if reduction == "none": if reduction == "none":
return x return x
elif reduction == "mean": if reduction == "mean":
return x.mean(dim=0) return x.mean(dim=0)
elif reduction == "sum": if reduction == "sum":
return x.sum(dim=0) return x.sum(dim=0)
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")

View File

@ -307,7 +307,7 @@ class Synthesizer(object):
waveform = waveform.squeeze() waveform = waveform.squeeze()
# trim silence # trim silence
if self.tts_config.audio["do_trim_silence"] is True: if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
waveform = trim_silence(waveform, self.tts_model.ap) waveform = trim_silence(waveform, self.tts_model.ap)
wavs += list(waveform) wavs += list(waveform)

View File

@ -54,7 +54,6 @@ config = FastPitchConfig(
print_step=50, print_step=50,
print_eval=False, print_eval=False,
mixed_precision=False, mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000, max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],

View File

@ -53,7 +53,6 @@ config = FastSpeechConfig(
print_step=50, print_step=50,
print_eval=False, print_eval=False,
mixed_precision=False, mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000, max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],

View File

@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
print_step=50, print_step=50,
print_eval=False, print_eval=False,
mixed_precision=False, mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000, max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],

View File

@ -68,7 +68,6 @@ config = Tacotron2Config(
print_step=25, print_step=25,
print_eval=True, print_eval=True,
mixed_precision=False, mixed_precision=False,
sort_by_audio_len=True,
seq_len_norm=True, seq_len_norm=True,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig( dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
) )
audio_config = BaseAudioConfig( audio_config = VitsAudioConfig(
sample_rate=22050, sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
) )
config = VitsConfig( config = VitsConfig(
@ -37,7 +23,7 @@ config = VitsConfig(
batch_size=32, batch_size=32,
eval_batch_size=16, eval_batch_size=16,
batch_group_size=5, batch_group_size=5,
num_loader_workers=0, num_loader_workers=8,
num_eval_loader_workers=4, num_eval_loader_workers=4,
run_eval=True, run_eval=True,
test_delay_epochs=-1, test_delay_epochs=-1,
@ -52,6 +38,7 @@ config = VitsConfig(
mixed_precision=True, mixed_precision=True,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
cudnn_benchmark=False,
) )
# INITIALIZE THE AUDIO PROCESSOR # INITIALIZE THE AUDIO PROCESSOR

View File

@ -3,11 +3,10 @@ from glob import glob
from trainer import Trainer, TrainerArgs from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -22,22 +21,13 @@ dataset_config = [
for path in dataset_paths for path in dataset_paths
] ]
audio_config = BaseAudioConfig( audio_config = VitsAudioConfig(
sample_rate=16000, sample_rate=16000,
win_length=1024, win_length=1024,
hop_length=256, hop_length=256,
num_mels=80, num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=False,
trim_db=23.0,
mel_fmin=0, mel_fmin=0,
mel_fmax=None, mel_fmax=None,
spec_gain=1.0,
signal_norm=True,
do_amp_to_db_linear=False,
resample=False,
) )
vitsArgs = VitsArgs( vitsArgs = VitsArgs(
@ -69,7 +59,6 @@ config = VitsConfig(
use_language_weighted_sampler=True, use_language_weighted_sampler=True,
print_eval=False, print_eval=False,
mixed_precision=False, mixed_precision=False,
sort_by_audio_len=True,
min_audio_len=32 * 256 * 4, min_audio_len=32 * 256 * 4,
max_audio_len=160000, max_audio_len=160000,
output_path=output_path, output_path=output_path,

View File

@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
"Dieser Kuchen ist großartig. Er ist so lecker und feucht.", "Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
"Vor dem 22. November 1963.", "Vor dem 22. November 1963.",
], ],
sort_by_audio_len=True,
max_seq_len=500000, max_seq_len=500000,
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.downloaders import download_thorsten_de from TTS.utils.downloaders import download_thorsten_de
@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
print("Downloading dataset") print("Downloading dataset")
download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0]) download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])
audio_config = BaseAudioConfig( audio_config = VitsAudioConfig(
sample_rate=22050, sample_rate=22050,
win_length=1024, win_length=1024,
hop_length=256, hop_length=256,
num_mels=80, num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0, mel_fmin=0,
mel_fmax=None, mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
) )
config = VitsConfig( config = VitsConfig(

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
) )
audio_config = BaseAudioConfig( audio_config = VitsAudioConfig(
sample_rate=22050, sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
resample=True,
) )
vitsArgs = VitsArgs( vitsArgs = VitsArgs(
@ -62,6 +47,7 @@ config = VitsConfig(
max_text_len=325, # change this if you have a larger VRAM than 16GB max_text_len=325, # change this if you have a larger VRAM than 16GB
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
cudnn_benchmark=False,
) )
# INITIALIZE THE AUDIO PROCESSOR # INITIALIZE THE AUDIO PROCESSOR

View File

@ -90,7 +90,7 @@ setup(
# ext_modules=find_cython_extensions(), # ext_modules=find_cython_extensions(),
# package # package
include_package_data=True, include_package_data=True,
packages=find_packages(include=["TTS*"]), packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
package_data={ package_data={
"TTS": [ "TTS": [
"VERSION", "VERSION",

View File

@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec from TTS.tts.models.vits import (
Vits,
VitsArgs,
VitsAudioConfig,
amp_to_db,
db_to_amp,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
self._check_parameter_changes(model, model_ref) self._check_parameter_changes(model, model_ref)
def test_train_step_upsampling(self): def test_train_step_upsampling(self):
"""Upsampling by the decoder upsampling layers"""
# setup the model # setup the model
with torch.autograd.set_detect_anomaly(True): with torch.autograd.set_detect_anomaly(True):
audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs( model_args = VitsArgs(
num_chars=32, num_chars=32,
spec_segment_size=10, spec_segment_size=10,
@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
interpolate_z=False, interpolate_z=False,
upsample_rates_decoder=[8, 8, 4, 2], upsample_rates_decoder=[8, 8, 4, 2],
) )
config = VitsConfig(model_args=model_args) config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device) model = Vits(config).to(device)
model.train() model.train()
# model to train # model to train
@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
self._check_parameter_changes(model, model_ref) self._check_parameter_changes(model, model_ref)
def test_train_step_upsampling_interpolation(self): def test_train_step_upsampling_interpolation(self):
"""Upsampling by interpolation"""
# setup the model # setup the model
with torch.autograd.set_detect_anomaly(True): with torch.autograd.set_detect_anomaly(True):
model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True) audio_config = VitsAudioConfig(sample_rate=22050)
config = VitsConfig(model_args=model_args) model_args = VitsArgs(
num_chars=32,
spec_segment_size=10,
encoder_sample_rate=11025,
interpolate_z=True,
upsample_rates_decoder=[8, 8, 2, 2],
)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device) model = Vits(config).to(device)
model.train() model.train()
# model to train # model to train