mirror of https://github.com/coqui-ai/TTS.git
Implement VitsAudioConfig (#1556)
* Implement VitsAudioConfig
* Update VITS LJSpeech recipe
* Update VITS VCTK recipe
* Make style
* Add missing decorator
* Add missing param
* Make style
* Update recipes
* Fix test
* Bug fix
* Exclude tests folder
* Make linter
* Make style
This commit is contained in:
parent
34b80e0280
commit
49bac724c0
|
@ -11,4 +11,5 @@ recursive-include TTS *.md
|
||||||
recursive-include TTS *.py
|
recursive-include TTS *.py
|
||||||
recursive-include TTS *.pyx
|
recursive-include TTS *.pyx
|
||||||
recursive-include images *.png
|
recursive-include images *.png
|
||||||
|
recursive-exclude tests *
|
||||||
|
prune tests*
|
||||||
|
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
from TTS.tts.models.vits import VitsArgs
|
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
|
||||||
model_args (VitsArgs):
|
model_args (VitsArgs):
|
||||||
Model architecture arguments. Defaults to `VitsArgs()`.
|
Model architecture arguments. Defaults to `VitsArgs()`.
|
||||||
|
|
||||||
|
audio (VitsAudioConfig):
|
||||||
|
Audio processing configuration. Defaults to `VitsAudioConfig()`.
|
||||||
|
|
||||||
grad_clip (List):
|
grad_clip (List):
|
||||||
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
||||||
|
|
||||||
|
@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
|
||||||
model: str = "vits"
|
model: str = "vits"
|
||||||
# model specific params
|
# model specific params
|
||||||
model_args: VitsArgs = field(default_factory=VitsArgs)
|
model_args: VitsArgs = field(default_factory=VitsArgs)
|
||||||
|
audio: VitsAudioConfig = VitsAudioConfig()
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
|
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
|
||||||
|
|
|
@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# CONFIGS
|
||||||
|
#############################
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VitsAudioConfig(Coqpit):
|
||||||
|
fft_size: int = 1024
|
||||||
|
sample_rate: int = 22050
|
||||||
|
win_length: int = 1024
|
||||||
|
hop_length: int = 256
|
||||||
|
num_mels: int = 80
|
||||||
|
mel_fmin: int = 0
|
||||||
|
mel_fmax: int = None
|
||||||
|
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# DATASET
|
# DATASET
|
||||||
##############################
|
##############################
|
||||||
|
|
|
@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
if reduction == "none":
|
if reduction == "none":
|
||||||
return x
|
return x
|
||||||
elif reduction == "mean":
|
if reduction == "mean":
|
||||||
return x.mean(dim=0)
|
return x.mean(dim=0)
|
||||||
elif reduction == "sum":
|
if reduction == "sum":
|
||||||
return x.sum(dim=0)
|
return x.sum(dim=0)
|
||||||
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
||||||
|
|
||||||
|
|
|
@ -307,7 +307,7 @@ class Synthesizer(object):
|
||||||
waveform = waveform.squeeze()
|
waveform = waveform.squeeze()
|
||||||
|
|
||||||
# trim silence
|
# trim silence
|
||||||
if self.tts_config.audio["do_trim_silence"] is True:
|
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
|
||||||
waveform = trim_silence(waveform, self.tts_model.ap)
|
waveform = trim_silence(waveform, self.tts_model.ap)
|
||||||
|
|
||||||
wavs += list(waveform)
|
wavs += list(waveform)
|
||||||
|
|
|
@ -54,7 +54,6 @@ config = FastPitchConfig(
|
||||||
print_step=50,
|
print_step=50,
|
||||||
print_eval=False,
|
print_eval=False,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
sort_by_audio_len=True,
|
|
||||||
max_seq_len=500000,
|
max_seq_len=500000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
|
|
@ -53,7 +53,6 @@ config = FastSpeechConfig(
|
||||||
print_step=50,
|
print_step=50,
|
||||||
print_eval=False,
|
print_eval=False,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
sort_by_audio_len=True,
|
|
||||||
max_seq_len=500000,
|
max_seq_len=500000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
|
|
@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
|
||||||
print_step=50,
|
print_step=50,
|
||||||
print_eval=False,
|
print_eval=False,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
sort_by_audio_len=True,
|
|
||||||
max_seq_len=500000,
|
max_seq_len=500000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
|
|
@ -68,7 +68,6 @@ config = Tacotron2Config(
|
||||||
print_step=25,
|
print_step=25,
|
||||||
print_eval=True,
|
print_eval=True,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
sort_by_audio_len=True,
|
|
||||||
seq_len_norm=True,
|
seq_len_norm=True,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
|
|
@ -2,11 +2,10 @@ import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseAudioConfig
|
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models.vits import Vits
|
from TTS.tts.models.vits import Vits, VitsAudioConfig
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
dataset_config = BaseDatasetConfig(
|
dataset_config = BaseDatasetConfig(
|
||||||
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
|
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
|
||||||
)
|
)
|
||||||
audio_config = BaseAudioConfig(
|
audio_config = VitsAudioConfig(
|
||||||
sample_rate=22050,
|
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
|
||||||
win_length=1024,
|
|
||||||
hop_length=256,
|
|
||||||
num_mels=80,
|
|
||||||
preemphasis=0.0,
|
|
||||||
ref_level_db=20,
|
|
||||||
log_func="np.log",
|
|
||||||
do_trim_silence=True,
|
|
||||||
trim_db=45,
|
|
||||||
mel_fmin=0,
|
|
||||||
mel_fmax=None,
|
|
||||||
spec_gain=1.0,
|
|
||||||
signal_norm=False,
|
|
||||||
do_amp_to_db_linear=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config = VitsConfig(
|
config = VitsConfig(
|
||||||
|
@ -37,7 +23,7 @@ config = VitsConfig(
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
eval_batch_size=16,
|
eval_batch_size=16,
|
||||||
batch_group_size=5,
|
batch_group_size=5,
|
||||||
num_loader_workers=0,
|
num_loader_workers=8,
|
||||||
num_eval_loader_workers=4,
|
num_eval_loader_workers=4,
|
||||||
run_eval=True,
|
run_eval=True,
|
||||||
test_delay_epochs=-1,
|
test_delay_epochs=-1,
|
||||||
|
@ -52,6 +38,7 @@ config = VitsConfig(
|
||||||
mixed_precision=True,
|
mixed_precision=True,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
cudnn_benchmark=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# INITIALIZE THE AUDIO PROCESSOR
|
# INITIALIZE THE AUDIO PROCESSOR
|
||||||
|
|
|
@ -3,11 +3,10 @@ from glob import glob
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseAudioConfig
|
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
|
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
|
@ -22,22 +21,13 @@ dataset_config = [
|
||||||
for path in dataset_paths
|
for path in dataset_paths
|
||||||
]
|
]
|
||||||
|
|
||||||
audio_config = BaseAudioConfig(
|
audio_config = VitsAudioConfig(
|
||||||
sample_rate=16000,
|
sample_rate=16000,
|
||||||
win_length=1024,
|
win_length=1024,
|
||||||
hop_length=256,
|
hop_length=256,
|
||||||
num_mels=80,
|
num_mels=80,
|
||||||
preemphasis=0.0,
|
|
||||||
ref_level_db=20,
|
|
||||||
log_func="np.log",
|
|
||||||
do_trim_silence=False,
|
|
||||||
trim_db=23.0,
|
|
||||||
mel_fmin=0,
|
mel_fmin=0,
|
||||||
mel_fmax=None,
|
mel_fmax=None,
|
||||||
spec_gain=1.0,
|
|
||||||
signal_norm=True,
|
|
||||||
do_amp_to_db_linear=False,
|
|
||||||
resample=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
vitsArgs = VitsArgs(
|
vitsArgs = VitsArgs(
|
||||||
|
@ -69,7 +59,6 @@ config = VitsConfig(
|
||||||
use_language_weighted_sampler=True,
|
use_language_weighted_sampler=True,
|
||||||
print_eval=False,
|
print_eval=False,
|
||||||
mixed_precision=False,
|
mixed_precision=False,
|
||||||
sort_by_audio_len=True,
|
|
||||||
min_audio_len=32 * 256 * 4,
|
min_audio_len=32 * 256 * 4,
|
||||||
max_audio_len=160000,
|
max_audio_len=160000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
|
|
|
@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
|
||||||
"Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
|
"Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
|
||||||
"Vor dem 22. November 1963.",
|
"Vor dem 22. November 1963.",
|
||||||
],
|
],
|
||||||
sort_by_audio_len=True,
|
|
||||||
max_seq_len=500000,
|
max_seq_len=500000,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
|
|
@ -2,11 +2,10 @@ import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseAudioConfig
|
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models.vits import Vits
|
from TTS.tts.models.vits import Vits, VitsAudioConfig
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.downloaders import download_thorsten_de
|
from TTS.utils.downloaders import download_thorsten_de
|
||||||
|
@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
|
||||||
print("Downloading dataset")
|
print("Downloading dataset")
|
||||||
download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])
|
download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])
|
||||||
|
|
||||||
audio_config = BaseAudioConfig(
|
audio_config = VitsAudioConfig(
|
||||||
sample_rate=22050,
|
sample_rate=22050,
|
||||||
win_length=1024,
|
win_length=1024,
|
||||||
hop_length=256,
|
hop_length=256,
|
||||||
num_mels=80,
|
num_mels=80,
|
||||||
preemphasis=0.0,
|
|
||||||
ref_level_db=20,
|
|
||||||
log_func="np.log",
|
|
||||||
do_trim_silence=True,
|
|
||||||
trim_db=45,
|
|
||||||
mel_fmin=0,
|
mel_fmin=0,
|
||||||
mel_fmax=None,
|
mel_fmax=None,
|
||||||
spec_gain=1.0,
|
|
||||||
signal_norm=False,
|
|
||||||
do_amp_to_db_linear=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config = VitsConfig(
|
config = VitsConfig(
|
||||||
|
|
|
@ -2,11 +2,10 @@ import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.config.shared_configs import BaseAudioConfig
|
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models.vits import Vits, VitsArgs
|
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
audio_config = BaseAudioConfig(
|
audio_config = VitsAudioConfig(
|
||||||
sample_rate=22050,
|
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
|
||||||
win_length=1024,
|
|
||||||
hop_length=256,
|
|
||||||
num_mels=80,
|
|
||||||
preemphasis=0.0,
|
|
||||||
ref_level_db=20,
|
|
||||||
log_func="np.log",
|
|
||||||
do_trim_silence=True,
|
|
||||||
trim_db=23.0,
|
|
||||||
mel_fmin=0,
|
|
||||||
mel_fmax=None,
|
|
||||||
spec_gain=1.0,
|
|
||||||
signal_norm=False,
|
|
||||||
do_amp_to_db_linear=False,
|
|
||||||
resample=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
vitsArgs = VitsArgs(
|
vitsArgs = VitsArgs(
|
||||||
|
@ -62,6 +47,7 @@ config = VitsConfig(
|
||||||
max_text_len=325, # change this if you have a larger VRAM than 16GB
|
max_text_len=325, # change this if you have a larger VRAM than 16GB
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
|
cudnn_benchmark=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# INITIALIZE THE AUDIO PROCESSOR
|
# INITIALIZE THE AUDIO PROCESSOR
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -90,7 +90,7 @@ setup(
|
||||||
# ext_modules=find_cython_extensions(),
|
# ext_modules=find_cython_extensions(),
|
||||||
# package
|
# package
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
packages=find_packages(include=["TTS*"]),
|
packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
|
||||||
package_data={
|
package_data={
|
||||||
"TTS": [
|
"TTS": [
|
||||||
"VERSION",
|
"VERSION",
|
||||||
|
|
|
@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
|
from TTS.tts.models.vits import (
|
||||||
|
Vits,
|
||||||
|
VitsArgs,
|
||||||
|
VitsAudioConfig,
|
||||||
|
amp_to_db,
|
||||||
|
db_to_amp,
|
||||||
|
load_audio,
|
||||||
|
spec_to_mel,
|
||||||
|
wav_to_mel,
|
||||||
|
wav_to_spec,
|
||||||
|
)
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
|
||||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||||
|
@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
|
||||||
self._check_parameter_changes(model, model_ref)
|
self._check_parameter_changes(model, model_ref)
|
||||||
|
|
||||||
def test_train_step_upsampling(self):
|
def test_train_step_upsampling(self):
|
||||||
|
"""Upsampling by the decoder upsampling layers"""
|
||||||
# setup the model
|
# setup the model
|
||||||
with torch.autograd.set_detect_anomaly(True):
|
with torch.autograd.set_detect_anomaly(True):
|
||||||
|
audio_config = VitsAudioConfig(sample_rate=22050)
|
||||||
model_args = VitsArgs(
|
model_args = VitsArgs(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
spec_segment_size=10,
|
spec_segment_size=10,
|
||||||
|
@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
|
||||||
interpolate_z=False,
|
interpolate_z=False,
|
||||||
upsample_rates_decoder=[8, 8, 4, 2],
|
upsample_rates_decoder=[8, 8, 4, 2],
|
||||||
)
|
)
|
||||||
config = VitsConfig(model_args=model_args)
|
config = VitsConfig(model_args=model_args, audio=audio_config)
|
||||||
model = Vits(config).to(device)
|
model = Vits(config).to(device)
|
||||||
model.train()
|
model.train()
|
||||||
# model to train
|
# model to train
|
||||||
|
@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
|
||||||
self._check_parameter_changes(model, model_ref)
|
self._check_parameter_changes(model, model_ref)
|
||||||
|
|
||||||
def test_train_step_upsampling_interpolation(self):
|
def test_train_step_upsampling_interpolation(self):
|
||||||
|
"""Upsampling by interpolation"""
|
||||||
# setup the model
|
# setup the model
|
||||||
with torch.autograd.set_detect_anomaly(True):
|
with torch.autograd.set_detect_anomaly(True):
|
||||||
model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
|
audio_config = VitsAudioConfig(sample_rate=22050)
|
||||||
config = VitsConfig(model_args=model_args)
|
model_args = VitsArgs(
|
||||||
|
num_chars=32,
|
||||||
|
spec_segment_size=10,
|
||||||
|
encoder_sample_rate=11025,
|
||||||
|
interpolate_z=True,
|
||||||
|
upsample_rates_decoder=[8, 8, 2, 2],
|
||||||
|
)
|
||||||
|
config = VitsConfig(model_args=model_args, audio=audio_config)
|
||||||
model = Vits(config).to(device)
|
model = Vits(config).to(device)
|
||||||
model.train()
|
model.train()
|
||||||
# model to train
|
# model to train
|
||||||
|
|
Loading…
Reference in New Issue