diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 14c76add..78ff00c2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1628,13 +1628,23 @@ class Vits(BaseTTS): pin_memory=False, ) else: - loader = DataLoader( + if num_gpus > 1: + loader = DataLoader( dataset, - batch_sampler=sampler, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) return loader def get_optimizer(self) -> List: