From 39aff6685efb124fc7de81535f4033ca3fbbf37e Mon Sep 17 00:00:00 2001 From: Edresson Date: Sun, 19 Sep 2021 21:06:58 -0300 Subject: [PATCH] Add options to freeze the flow-based decoder and the waveform decoder (vocoder generator) --- TTS/tts/models/vits.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index c24fec68..212e7779 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -225,6 +225,8 @@ class VitsArgs(Coqpit): freeze_encoder: bool = False freeze_DP: bool = False freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False @@ -787,9 +789,11 @@ class Vits(BaseTTS): if self.args.freeze_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False - for param in self.emb_l.parameters(): - param.requires_grad = False - + + if hasattr(self, 'emb_l'): + for param in self.emb_l.parameters(): + param.requires_grad = False + if self.args.freeze_PE: for param in self.posterior_encoder.parameters(): param.requires_grad = False @@ -798,6 +802,14 @@ class Vits(BaseTTS): for param in self.duration_predictor.parameters(): param.requires_grad = False + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"]