Handle when no batch sampler (#1882)

Eren Gölge 2022-08-18 11:26:04 +02:00 committed by GitHub
parent 7442bcefa5
commit fcb0bb58ae
1 changed file with 18 additions and 8 deletions

@@ -1613,14 +1613,24 @@ class Vits(BaseTTS):
             # get samplers
             sampler = self.get_sampler(config, dataset, num_gpus)
-            loader = DataLoader(
-                dataset,
-                batch_sampler=sampler,
-                collate_fn=dataset.collate_fn,
-                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
-                pin_memory=False,
-            )
+            if sampler is None:
+                loader = DataLoader(
+                    dataset,
+                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                    shuffle=False,  # shuffle is done in the dataset.
+                    collate_fn=dataset.collate_fn,
+                    drop_last=False,  # setting this False might cause issues in AMP training.
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
+            else:
+                loader = DataLoader(
+                    dataset,
+                    batch_sampler=sampler,
+                    collate_fn=dataset.collate_fn,
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
         return loader

     def get_optimizer(self) -> List:
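
For context, here is a minimal sketch (not part of the commit) of the PyTorch DataLoader behavior that motivates the branch: batch_sampler is mutually exclusive with batch_size, shuffle, and drop_last, and when no batch sampler is available the batch size has to be passed explicitly, otherwise DataLoader silently falls back to its default batch size of 1. ToyDataset and the batch size of 4 below are made-up stand-ins, not names from the TTS codebase.

# Sketch only: illustrates the DataLoader constraint the if/else above works around.
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyDataset(Dataset):  # hypothetical stand-in for the real VITS dataset
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor(idx)


dataset = ToyDataset()

# With a batch sampler: batch_size/shuffle/drop_last must NOT be passed as well,
# or PyTorch raises "batch_sampler option is mutually exclusive with
# batch_size, shuffle, sampler, and drop_last".
sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
loader_with_sampler = DataLoader(dataset, batch_sampler=sampler)

# Without a batch sampler: batch_size has to be given explicitly,
# otherwise DataLoader defaults to batch_size=1.
loader_plain = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=False)

for batch in loader_plain:
    print(batch.shape)  # torch.Size([4]) for full batches, smaller for the remainder

This is presumably why the fallback branch passes the configured batch_size (and eval_batch_size) directly instead of relying on DataLoader defaults when get_sampler returns None.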