From de41165af46f8a6e4b617e3958afe4b7cdba44d8 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Sun, 19 Sep 2021 23:35:31 +0200 Subject: [PATCH] freeze vits parts --- TTS/tts/models/vits.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 334e4526..c24fec68 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -222,6 +222,9 @@ class VitsArgs(Coqpit): speaker_encoder_config_path: str = "" speaker_encoder_model_path: str = "" fine_tuning_mode: int = 0 + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False @@ -781,6 +784,20 @@ class Vits(BaseTTS): self.waveform_decoder.train() self.disc.train() + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"]