Update vits.py

This should fix the issue https://github.com/coqui-ai/TTS/issues/1986 without breaking batch data sampling.
This commit is contained in:
Dani Vera 2023-03-10 18:35:16 +01:00 committed by GitHub
parent 624513018d
commit 0d12229b64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 2 deletions

View File

@ -1628,13 +1628,23 @@ class Vits(BaseTTS):
pin_memory=False,
)
else:
loader = DataLoader(
if num_gpus > 1:
loader = DataLoader(
dataset,
batch_sampler=sampler,
sampler=sampler,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_optimizer(self) -> List: