Ensures that only GPT model is in training mode during training

This commit is contained in:
Edresson Casanova 2023-11-16 14:41:12 -03:00
parent 7e4375da2b
commit b87a665a40
2 changed files with 5 additions and 4 deletions

View File

@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
batch["cond_idxs"] = None
return self.train_step(batch, criterion)
def on_epoch_start(self, trainer): # pylint: disable=W0613
# guarante that dvae will be in eval mode after .train() on evaluation end
self.dvae = self.dvae.eval()
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload

View File

@ -27,7 +27,7 @@ pandas>=1.4,<2.0
# deps for training
matplotlib==3.7.*
# coqui stack
trainer
trainer>=0.0.32
# config management
coqpit>=0.0.16
# chinese g2p deps