mirror of https://github.com/coqui-ai/TTS.git
add get_cuda()
This commit is contained in:
parent
21dd4d7960
commit
a21ac883dd
|
@ -1,4 +1,3 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
|
@ -8,10 +7,17 @@ import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda():
|
||||||
|
use_cuda = torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
return use_cuda, device
|
||||||
|
|
||||||
|
|
||||||
def get_git_branch():
|
def get_git_branch():
|
||||||
try:
|
try:
|
||||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||||
|
|
Loading…
Reference in New Issue