From 0d12229b642ead3b56294899faf8001597ab6298 Mon Sep 17 00:00:00 2001 From: Dani Vera <28764301+dveni@users.noreply.github.com> Date: Fri, 10 Mar 2023 18:35:16 +0100 Subject: [PATCH] Update vits.py This should fix the issue https://github.com/coqui-ai/TTS/issues/1986 without breaking batch data sampling. --- TTS/tts/models/vits.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 14c76add..78ff00c2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1628,13 +1628,23 @@ class Vits(BaseTTS): pin_memory=False, ) else: - loader = DataLoader( + if num_gpus > 1: + loader = DataLoader( dataset, - batch_sampler=sampler, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) return loader def get_optimizer(self) -> List: