get_device_id() for tests

This commit is contained in:
Eren Gölge 2021-05-10 15:24:34 +02:00
parent 6e980b49c4
commit 87384c6008
1 changed files with 11 additions and 0 deletions

View File

@ -1,5 +1,16 @@
import os
from TTS.utils.generic_utils import get_cuda
def get_device_id():
use_cuda, _ = get_cuda()
if use_cuda:
GPU_ID = "0"
else:
GPU_ID = ""
return GPU_ID
def get_tests_path():
"""Returns the path to the test directory."""