add get_cuda()

This commit is contained in:
Eren Gölge 2021-05-10 15:18:58 +02:00
parent 21dd4d7960
commit a21ac883dd
1 changed files with 7 additions and 1 deletions

View File

@ -1,4 +1,3 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime
import glob import glob
@ -8,10 +7,17 @@ import re
import shutil import shutil
import subprocess import subprocess
import sys import sys
import torch
from pathlib import Path from pathlib import Path
from typing import List from typing import List
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return use_cuda, device
def get_git_branch(): def get_git_branch():
try: try:
out = subprocess.check_output(["git", "branch"]).decode("utf8") out = subprocess.check_output(["git", "branch"]).decode("utf8")