add get_cuda()

This commit is contained in:
Eren Gölge 2021-05-10 15:18:58 +02:00
parent 21dd4d7960
commit a21ac883dd
1 changed files with 7 additions and 1 deletions

View File

@ -1,4 +1,3 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import glob
@ -8,10 +7,17 @@ import re
import shutil
import subprocess
import sys
import torch
from pathlib import Path
from typing import List
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return use_cuda, device
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")