refactor: only use keyword args in Synthesizer

This commit is contained in:
Enno Hermann 2024-11-29 16:27:14 +01:00
parent 6927e0bb89
commit 546f43cb25
3 changed files with 14 additions and 13 deletions

View File

@ -407,18 +407,18 @@ def main():
# load models # load models
synthesizer = Synthesizer( synthesizer = Synthesizer(
tts_path, tts_checkpoint=tts_path,
tts_config_path, tts_config_path=tts_config_path,
speakers_file_path, tts_speakers_file=speakers_file_path,
language_ids_file_path, tts_languages_file=language_ids_file_path,
vocoder_path, vocoder_checkpoint=vocoder_path,
vocoder_config_path, vocoder_config=vocoder_config_path,
encoder_path, encoder_checkpoint=encoder_path,
encoder_config_path, encoder_config=encoder_config_path,
vc_path, vc_checkpoint=vc_path,
vc_config_path, vc_config=vc_config_path,
model_dir, model_dir=model_dir,
args.voice_dir, voice_dir=args.voice_dir,
).to(device) ).to(device)
# query speaker ids of a multi-speaker model. # query speaker ids of a multi-speaker model.

View File

@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class Synthesizer(nn.Module): class Synthesizer(nn.Module):
def __init__( def __init__(
self, self,
*,
tts_checkpoint: str = "", tts_checkpoint: str = "",
tts_config_path: str = "", tts_config_path: str = "",
tts_speakers_file: str = "", tts_speakers_file: str = "",

View File

@ -23,7 +23,7 @@ class SynthesizerTest(unittest.TestCase):
tts_root_path = get_tests_input_path() tts_root_path = get_tests_input_path()
tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth") tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth")
tts_config = os.path.join(tts_root_path, "dummy_model_config.json") tts_config = os.path.join(tts_root_path, "dummy_model_config.json")
synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None) synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config)
synthesizer.tts("Better this test works!!") synthesizer.tts("Better this test works!!")
def test_split_into_sentences(self): def test_split_into_sentences(self):