support multiple GPU training

This commit is contained in:
Aaron-Li 2023-12-08 16:55:32 +08:00
parent c99e885cc8
commit b6e929696a
1 changed files with 6 additions and 2 deletions

View File

@ -321,7 +321,10 @@ class GPTTrainer(BaseTTS):
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()
if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"):
trainer.model.module.xtts.gpt.train()
else:
trainer.model.xtts.gpt.train()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
@ -387,7 +390,8 @@ class GPTTrainer(BaseTTS):
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
sampler=sampler,
batch_size = config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,