mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #3391 from aaron-lii/multi-gpu
support multiple GPU training for XTTS
This commit is contained in:
commit
934b87bbd1
|
@ -321,7 +321,10 @@ class GPTTrainer(BaseTTS):
|
|||
def on_train_epoch_start(self, trainer):
|
||||
trainer.model.eval() # the whole model to eval
|
||||
# put gpt model in training mode
|
||||
trainer.model.xtts.gpt.train()
|
||||
if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"):
|
||||
trainer.model.module.xtts.gpt.train()
|
||||
else:
|
||||
trainer.model.xtts.gpt.train()
|
||||
|
||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||
# ignore similarities.pth on clearml save/upload
|
||||
|
@ -387,7 +390,8 @@ class GPTTrainer(BaseTTS):
|
|||
else:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
sampler=sampler,
|
||||
batch_size = config.eval_batch_size if is_eval else config.batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
|
|
Loading…
Reference in New Issue