diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 730506c1..59409ad0 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -5,6 +5,7 @@ import os import sys import time +import itertools import traceback from inspect import signature @@ -495,7 +496,11 @@ def main(args): # pylint: disable=redefined-outer-name optimizer_gen = getattr(torch.optim, c.optimizer) optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params) optimizer_disc = getattr(torch.optim, c.optimizer) - optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) + + if c.discriminator_model == 'hifigan_discriminator': + optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params) + else: + optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) # schedulers scheduler_gen = None