mirror of https://github.com/coqui-ai/TTS.git
Ensures that only GPT model is in training mode during XTTS GPT training (#3241)
* Ensures that only GPT model is in training mode during training * Fix parallel wavegan unit test
This commit is contained in:
parent
14579a4607
commit
11283fce07
|
@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
|
||||||
batch["cond_idxs"] = None
|
batch["cond_idxs"] = None
|
||||||
return self.train_step(batch, criterion)
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
def on_train_epoch_start(self, trainer):
|
||||||
# guarante that dvae will be in eval mode after .train() on evaluation end
|
trainer.model.eval() # the whole model to eval
|
||||||
self.dvae = self.dvae.eval()
|
# put gpt model in training mode
|
||||||
|
trainer.model.xtts.gpt.train()
|
||||||
|
|
||||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||||
# ignore similarities.pth on clearml save/upload
|
# ignore similarities.pth on clearml save/upload
|
||||||
|
|
|
@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
steps_to_start_discriminator: int = 200000
|
steps_to_start_discriminator: int = 200000
|
||||||
|
target_loss: str = "loss_1"
|
||||||
|
|
||||||
# LOSS PARAMETERS - overrides
|
# LOSS PARAMETERS - overrides
|
||||||
use_stft_loss: bool = True
|
use_stft_loss: bool = True
|
||||||
|
|
|
@ -27,7 +27,7 @@ pandas>=1.4,<2.0
|
||||||
# deps for training
|
# deps for training
|
||||||
matplotlib>=3.7.0
|
matplotlib>=3.7.0
|
||||||
# coqui stack
|
# coqui stack
|
||||||
trainer
|
trainer>=0.0.32
|
||||||
# config management
|
# config management
|
||||||
coqpit>=0.0.16
|
coqpit>=0.0.16
|
||||||
# chinese g2p deps
|
# chinese g2p deps
|
||||||
|
|
Loading…
Reference in New Issue