From adcc2f8299ff86429b89803180c591147c2b0c7a Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 21 Apr 2022 10:03:37 -0300 Subject: [PATCH] Add the period for VITS multi-period discriminator in model_args --- TTS/server/server.py | 2 +- TTS/tts/layers/vits/discriminator.py | 4 +--- TTS/tts/models/vits.py | 6 +++++- tests/tts_tests/test_vits_multilingual_speaker_emb_train.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/TTS/server/server.py b/TTS/server/server.py index 33896e4e..89fce493 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -114,7 +114,7 @@ synthesizer = Synthesizer( use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None ) -print("Multispeaker?", use_multi_speaker, synthesizer.tts_model.num_speakers) + speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) # TODO: set this from SpeakerManager use_gst = synthesizer.tts_config.get("use_gst", False) diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index e9d54713..148f283c 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -58,10 +58,8 @@ class VitsDiscriminator(nn.Module): use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. """ - def __init__(self, use_spectral_norm=False): + def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): super().__init__() - periods = [2, 3, 5, 7, 11] - self.nets = nn.ModuleList() self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 7ae7d9bc..8dcde6bd 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -362,6 +362,9 @@ class VitsArgs(Coqpit): upsample_kernel_sizes_decoder (List[int]): Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. + periods_multi_period_discriminator (List[int]): + Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`. + use_sdp (bool): Use Stochastic Duration Predictor. Defaults to True. @@ -475,6 +478,7 @@ class VitsArgs(Coqpit): upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) upsample_initial_channel_decoder: int = 512 upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) use_sdp: bool = True noise_scale: float = 1.0 inference_noise_scale: float = 0.667 @@ -628,7 +632,7 @@ class Vits(BaseTTS): ) if self.args.init_discriminator: - self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator) + self.disc = VitsDiscriminator(periods=self.args.periods_multi_period_discriminator, use_spectral_norm=self.args.use_spectral_norm_disriminator) if self.args.TTS_part_sample_rate: self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index 9f9eead9..7b9b6335 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -73,10 +73,11 @@ config.use_language_weighted_sampler = True # test upsample config.model_args.TTS_part_sample_rate = 11025 -config.model_args.interpolate_z = True +config.model_args.interpolate_z = False config.model_args.detach_z_vocoder = True -config.model_args.upsample_rates_decoder = [8, 8, 2, 2] +config.model_args.upsample_rates_decoder = [8, 8, 4, 2] +config.model_args.periods_multi_period_discriminator = [2, 3, 5, 7, 11, 13, 17, 19, 23] config.save_json(config_path)