From a21ac883dd698835adc2d31fe8fe69469570c71d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 10 May 2021 15:18:58 +0200 Subject: [PATCH] add get_cuda() --- TTS/utils/generic_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index e8beff88..5473d32d 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- import datetime import glob @@ -8,10 +7,17 @@ import re import shutil import subprocess import sys +import torch from pathlib import Path from typing import List +def get_cuda(): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return use_cuda, device + + def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8")