From 0d12229b642ead3b56294899faf8001597ab6298 Mon Sep 17 00:00:00 2001 From: Dani Vera <28764301+dveni@users.noreply.github.com> Date: Fri, 10 Mar 2023 18:35:16 +0100 Subject: [PATCH 1/2] Update vits.py This should fix the issue https://github.com/coqui-ai/TTS/issues/1986 without breaking batch data sampling. --- TTS/tts/models/vits.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 14c76add..78ff00c2 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1628,13 +1628,23 @@ class Vits(BaseTTS): pin_memory=False, ) else: - loader = DataLoader( + if num_gpus > 1: + loader = DataLoader( dataset, - batch_sampler=sampler, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) return loader def get_optimizer(self) -> List: From dfb48737fbe40c341dff52b98c628db20257f8fb Mon Sep 17 00:00:00 2001 From: Daniel Vera Nieto Date: Mon, 13 Mar 2023 16:11:15 +0100 Subject: [PATCH 2/2] Style fixed --- TTS/tts/models/vits.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 78ff00c2..7500da61 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1630,13 +1630,13 @@ class Vits(BaseTTS): else: if num_gpus > 1: loader = DataLoader( - dataset, - sampler=sampler, - batch_size=config.eval_batch_size if is_eval else config.batch_size, - collate_fn=dataset.collate_fn, - num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, - pin_memory=False, - ) + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) else: loader = DataLoader( dataset,