Handle when no batch sampler (#1882)

Eren Gölge 2022-08-18 11:26:04 +02:00 committed by GitHub
parent 7442bcefa5
commit fcb0bb58ae
1 changed file with 18 additions and 8 deletions


@@ -1613,14 +1613,24 @@ class Vits(BaseTTS):
             # get samplers
             sampler = self.get_sampler(config, dataset, num_gpus)
-            loader = DataLoader(
-                dataset,
-                batch_sampler=sampler,
-                collate_fn=dataset.collate_fn,
-                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
-                pin_memory=False,
-            )
+            if sampler is None:
+                loader = DataLoader(
+                    dataset,
+                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                    shuffle=False,  # shuffle is done in the dataset.
+                    collate_fn=dataset.collate_fn,
+                    drop_last=False,  # setting this False might cause issues in AMP training.
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
+            else:
+                loader = DataLoader(
+                    dataset,
+                    batch_sampler=sampler,
+                    collate_fn=dataset.collate_fn,
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
         return loader
 
     def get_optimizer(self) -> List:
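
For context, here is a minimal, self-contained sketch of the same pattern outside the Vits class. ToyDataset, get_loader, and the batch size are hypothetical stand-ins, not code from this commit; the point is that DataLoader(dataset, batch_sampler=None) is legal but quietly falls back to PyTorch's default batch_size=1, which is presumably why the explicit fallback branch with the configured batch size was added.

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the real Vits dataset."""

    def __init__(self, n=10):
        self.items = list(range(n))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def get_loader(dataset, sampler, batch_size):
    # Mirrors the patched logic: when no batch sampler is available, build the
    # loader with an explicit batch size instead of passing batch_sampler=None,
    # which would silently use the default batch_size=1.
    if sampler is None:
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return DataLoader(dataset, batch_sampler=sampler)


if __name__ == "__main__":
    loader = get_loader(ToyDataset(), sampler=None, batch_size=4)
    print([batch.tolist() for batch in loader])  # batches of 4 items, not single-item batches

Running the sketch yields batches such as [0, 1, 2, 3] rather than one item per batch, matching the behavior the fallback branch is meant to restore.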