From d2b6326b8bc9a4f7f9aabd0137e9f50b76b7d388 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 23 Apr 2021 07:54:39 -0300 Subject: [PATCH] change optimizer initialization for compatibility with Hifi-GAN official implementation --- TTS/bin/train_vocoder_gan.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 730506c1..59409ad0 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -5,6 +5,7 @@ import os import sys import time +import itertools import traceback from inspect import signature @@ -495,7 +496,11 @@ def main(args): # pylint: disable=redefined-outer-name optimizer_gen = getattr(torch.optim, c.optimizer) optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params) optimizer_disc = getattr(torch.optim, c.optimizer) - optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) + + if c.discriminator_model == 'hifigan_discriminator': + optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params) + else: + optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) # schedulers scheduler_gen = None