import condition update for synthesis with TF

This commit is contained in:
erogol 2020-05-12 13:49:49 +02:00
parent 84c5c4a587
commit 68dbcee746
1 changed files with 3 additions and 3 deletions

View File

@ -1,6 +1,6 @@
import pkg_resources
installed = {pkg.key for pkg in pkg_resources.working_set}
if 'tensorflow' in installed:
if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
import tensorflow as tf
import torch
import numpy as np
@ -24,9 +24,9 @@ def text_to_seqvec(text, CONFIG, use_cuda):
def numpy_to_torch(np_array, dtype, cuda=False):
if np_array is None:
return None
tensor = torch.Tensor(np_array, dtype=dtype)
tensor = torch.Tensor(np_array, dtype=dtype)
if cuda:
return tensor.cuda()
return tensor.cuda()
return tensor