From 68dbcee746a775df42250b1940b54e88ae670e62 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 13:49:49 +0200 Subject: [PATCH] import condition update for synthesis with TF --- utils/synthesis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/synthesis.py b/utils/synthesis.py index 0c68dbf2..ae3a7df7 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -1,6 +1,6 @@ import pkg_resources installed = {pkg.key for pkg in pkg_resources.working_set} -if 'tensorflow' in installed: +if 'tensorflow' in installed or 'tensorflow-gpu' in installed: import tensorflow as tf import torch import numpy as np @@ -24,9 +24,9 @@ def text_to_seqvec(text, CONFIG, use_cuda): def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None - tensor = torch.Tensor(np_array, dtype=dtype) + tensor = torch.Tensor(np_array, dtype=dtype) if cuda: - return tensor.cuda() + return tensor.cuda() return tensor