diff --git a/utils/synthesis.py b/utils/synthesis.py index 0c68dbf2..ae3a7df7 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -1,6 +1,6 @@ import pkg_resources installed = {pkg.key for pkg in pkg_resources.working_set} -if 'tensorflow' in installed: +if 'tensorflow' in installed or 'tensorflow-gpu' in installed: import tensorflow as tf import torch import numpy as np @@ -24,9 +24,9 @@ def text_to_seqvec(text, CONFIG, use_cuda): def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None - tensor = torch.Tensor(np_array, dtype=dtype) + tensor = torch.Tensor(np_array, dtype=dtype) if cuda: - return tensor.cuda() + return tensor.cuda() return tensor