mirror of https://github.com/coqui-ai/TTS.git
support multiple GPU training
This commit is contained in:
parent c99e885cc8
commit b6e929696a
@@ -321,7 +321,10 @@ class GPTTrainer(BaseTTS):
     def on_train_epoch_start(self, trainer):
         trainer.model.eval()  # the whole model to eval
         # put gpt model in training mode
-        trainer.model.xtts.gpt.train()
+        if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"):
+            trainer.model.module.xtts.gpt.train()
+        else:
+            trainer.model.xtts.gpt.train()
 
     def on_init_end(self, trainer):  # pylint: disable=W0613
         # ignore similarities.pth on clearml save/upload
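The new branch handles the multi-GPU case: when the trainer wraps the model (e.g. in torch.nn.parallel.DistributedDataParallel), custom attributes such as xtts are only reachable through the wrapper's .module attribute. The snippet below is a minimal sketch of that unwrapping pattern, not code from this repository; nn.DataParallel is used only because it wraps a model the same way without requiring an initialized process group.

# Minimal sketch (assumption: the trainer's multi-GPU wrapper exposes the
# original model via .module, as DistributedDataParallel/DataParallel do).
import torch.nn as nn

xtts_like = nn.Linear(4, 2)           # stand-in for the XTTS model
wrapped = nn.DataParallel(xtts_like)  # wraps the model under .module

# Custom attributes of the inner model are not forwarded by the wrapper,
# so the diff unwraps it before touching .xtts:
model = wrapped.module if hasattr(wrapped, "module") else wrapped
model.train()                         # reaches the underlying module either way
print(type(model))                    # <class 'torch.nn.modules.linear.Linear'>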
@@ -387,7 +390,8 @@ class GPTTrainer(BaseTTS):
         else:
             loader = DataLoader(
                 dataset,
-                batch_sampler=sampler,
+                sampler=sampler,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
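The second hunk switches the loader from batch_sampler= to sampler= plus an explicit batch_size=. In a multi-GPU run the sampler here is presumably a DistributedSampler, which yields one index at a time (each rank sees its own shard of the data), so it has to go through sampler=; batch_sampler= expects a sampler that yields whole batches and is mutually exclusive with batch_size. The sketch below illustrates the resulting DataLoader call under that assumption; the dataset, batch size, and rank values are hypothetical, the actual code reads them from config.

# Minimal sketch (assumption: sampler is a DistributedSampler in multi-GPU runs).
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())   # stand-in for the TTS dataset

# num_replicas/rank are given explicitly so no process group is required here.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)

loader = DataLoader(
    dataset,
    sampler=sampler,   # per-sample sampler -> each rank iterates its own shard
    batch_size=4,      # hypothetical value; the diff reads it from config
    pin_memory=False,
)

for (batch,) in loader:
    print(batch)       # rank 0 sees every other sample: 0, 2, 4, ...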