mirror of https://github.com/coqui-ai/TTS.git
freeze vits parts
This commit is contained in:
parent
9d2c445e3d
commit
de41165af4
|
@ -222,6 +222,9 @@ class VitsArgs(Coqpit):
|
|||
speaker_encoder_config_path: str = ""
|
||||
speaker_encoder_model_path: str = ""
|
||||
fine_tuning_mode: int = 0
|
||||
freeze_encoder: bool = False
|
||||
freeze_DP: bool = False
|
||||
freeze_PE: bool = False
|
||||
|
||||
|
||||
|
||||
|
@ -781,6 +784,20 @@ class Vits(BaseTTS):
|
|||
self.waveform_decoder.train()
|
||||
self.disc.train()
|
||||
|
||||
if self.args.freeze_encoder:
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
for param in self.emb_l.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_PE:
|
||||
for param in self.posterior_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_DP:
|
||||
for param in self.duration_predictor.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
|
Loading…
Reference in New Issue