diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index c24fec68..212e7779 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -225,6 +225,8 @@ class VitsArgs(Coqpit): freeze_encoder: bool = False freeze_DP: bool = False freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False @@ -787,9 +789,11 @@ class Vits(BaseTTS): if self.args.freeze_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False - for param in self.emb_l.parameters(): - param.requires_grad = False - + + if hasattr(self, 'emb_l'): + for param in self.emb_l.parameters(): + param.requires_grad = False + if self.args.freeze_PE: for param in self.posterior_encoder.parameters(): param.requires_grad = False @@ -798,6 +802,14 @@ class Vits(BaseTTS): for param in self.duration_predictor.parameters(): param.requires_grad = False + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"]