refactor(api): require keyword arguments except for model_name

This commit is contained in:
Enno Hermann 2024-12-03 22:09:03 +01:00
parent 8c381e3e48
commit 5cfb4ecccd
1 changed files with 7 additions and 5 deletions

View File

@ -2,6 +2,7 @@ import logging
import tempfile import tempfile
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Optional
from torch import nn from torch import nn
@ -19,12 +20,13 @@ class TTS(nn.Module):
def __init__( def __init__(
self, self,
model_name: str = "", model_name: str = "",
model_path: str = None, *,
config_path: str = None, model_path: Optional[str] = None,
vocoder_path: str = None, config_path: Optional[str] = None,
vocoder_config_path: str = None, vocoder_path: Optional[str] = None,
vocoder_config_path: Optional[str] = None,
progress_bar: bool = True, progress_bar: bool = True,
gpu=False, gpu: bool = False,
): ):
"""🐸TTS python interface that allows to load and use the released models. """🐸TTS python interface that allows to load and use the released models.