Merge pull request #3391 from aaron-lii/multi-gpu

support multiple GPU training for XTTS
This commit is contained in:
Eren Gölge 2023-12-12 13:51:26 +01:00 committed by GitHub
commit 934b87bbd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 2 deletions

View File

@ -321,7 +321,10 @@ class GPTTrainer(BaseTTS):
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()
if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"):
trainer.model.module.xtts.gpt.train()
else:
trainer.model.xtts.gpt.train()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
@ -387,7 +390,8 @@ class GPTTrainer(BaseTTS):
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
sampler=sampler,
batch_size = config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,