From 844abb3b1d05e990400fb13aff76f4e8a4029949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 May 2021 10:38:44 +0200 Subject: [PATCH] `setup_loss()` in `layer/__init__.py` --- TTS/tts/layers/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/TTS/tts/layers/__init__.py b/TTS/tts/layers/__init__.py index e69de29b..78f56a5d 100644 --- a/TTS/tts/layers/__init__.py +++ b/TTS/tts/layers/__init__.py @@ -0,0 +1,15 @@ +from TTS.tts.layers.losses import * + + +def setup_loss(config): + if config.model.lower() in ["tacotron", "tacotron2"]: + model = TacotronLoss(config) + elif config.model.lower() == "glow_tts": + model = GlowTTSLoss() + elif config.model.lower() == "speedy_speech": + model = SpeedySpeechLoss(config) + elif config.model.lower() == "align_tts": + model = AlignTTSLoss(config) + else: + raise ValueError(f" [!] loss for model {config.model.lower()} cannot be found.") + return model