mirror of https://github.com/coqui-ai/TTS.git
Ensures that only GPT model is in training mode during training
This commit is contained in:
parent
7e4375da2b
commit
b87a665a40
|
@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
|
||||||
batch["cond_idxs"] = None
|
batch["cond_idxs"] = None
|
||||||
return self.train_step(batch, criterion)
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
def on_train_epoch_start(self, trainer):
|
||||||
# guarante that dvae will be in eval mode after .train() on evaluation end
|
trainer.model.eval() # the whole model to eval
|
||||||
self.dvae = self.dvae.eval()
|
# put gpt model in training mode
|
||||||
|
trainer.model.xtts.gpt.train()
|
||||||
|
|
||||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||||
# ignore similarities.pth on clearml save/upload
|
# ignore similarities.pth on clearml save/upload
|
||||||
|
|
|
@ -27,7 +27,7 @@ pandas>=1.4,<2.0
|
||||||
# deps for training
|
# deps for training
|
||||||
matplotlib==3.7.*
|
matplotlib==3.7.*
|
||||||
# coqui stack
|
# coqui stack
|
||||||
trainer
|
trainer>=0.0.32
|
||||||
# config management
|
# config management
|
||||||
coqpit>=0.0.16
|
coqpit>=0.0.16
|
||||||
# chinese g2p deps
|
# chinese g2p deps
|
||||||
|
|
Loading…
Reference in New Issue