mirror of https://github.com/coqui-ai/TTS.git
Skip TF tests on GPU
This commit is contained in:
parent 1ebf9ec6bf
commit 7ec23e69d4
@@ -38,6 +38,7 @@ class TacotronTFTrainTest(unittest.TestCase):
         mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
         return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids
 
+    @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
     def test_train_step(self):
         """test forward pass"""
         (
@@ -70,6 +71,7 @@ class TacotronTFTrainTest(unittest.TestCase):
         # inference pass
         output = model(chars_seq, training=False)
 
+    @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
     def test_forward_attention(
         self,
     ):
@@ -103,6 +105,7 @@ class TacotronTFTrainTest(unittest.TestCase):
         # inference pass
         output = model(chars_seq, training=False)
 
+    @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
     def test_tflite_conversion(
         self,
     ):  # pylint:disable=no-self-use
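For context, the three hunks above all apply the same pattern: a module-level use_cuda flag derived from torch.cuda.is_available() gates each TF test through unittest.skipIf. Below is a minimal, self-contained sketch of that pattern; the class name and test body are placeholders for illustration (only the decorator and the CUDA check mirror the commit), and it assumes torch is installed.

import unittest

import torch

# Detect a GPU once at import time, exactly like the test modules do.
use_cuda = torch.cuda.is_available()


class ExampleTFTest(unittest.TestCase):
    # On a GPU machine the runner marks this test as skipped with the
    # given reason; on CPU-only machines it runs normally.
    @unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
    def test_train_step(self):
        self.assertEqual(1 + 1, 2)  # placeholder body


if __name__ == "__main__":
    unittest.main()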
@@ -1,9 +1,15 @@
+import unittest
+
 import numpy as np
 import tensorflow as tf
+import torch
 
 from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
 
+use_cuda = torch.cuda.is_available()
 
+
+@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
 def test_melgan_generator():
     hop_length = 256
     model = MelganGenerator()
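Unlike the Tacotron tests above, test_melgan_generator is a plain function rather than a TestCase method. unittest.skipIf still applies: when the condition is true it wraps the function so that calling it raises unittest.SkipTest, which pytest also reports as a skip. A short sketch of that variant (the assertion is a placeholder; only the gating mirrors the commit):

import unittest

import torch

use_cuda = torch.cuda.is_available()


@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
def test_standalone_function():
    # Skipped on GPU machines; runs normally on CPU-only machines.
    assert 2 + 2 == 4  # placeholder body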
@@ -1,7 +1,9 @@
 import os
+import unittest
 
 import soundfile as sf
 import tensorflow as tf
+import torch
 from librosa.core import load
 
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
@@ -9,8 +11,10 @@ from TTS.vocoder.tf.layers.pqmf import PQMF
 
 TESTS_PATH = get_tests_path()
 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+use_cuda = torch.cuda.is_available()
 
 
+@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
 def test_pqmf():
     w, sr = load(WAV_FILE)
 
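To see the mechanism the commit relies on without any GPU or TTS dependencies, the sketch below forces the skip condition to True; the decorated function then raises unittest.SkipTest when called, and test runners translate that into a skipped result with the given reason. The function name and the forced condition are purely illustrative.

import unittest


@unittest.skipIf(True, "forced skip to illustrate the mechanism")
def example_test():
    return "ran"


try:
    example_test()
except unittest.SkipTest as exc:
    # unittest and pytest both catch this exception and report the
    # test as skipped, surfacing the reason passed to skipIf.
    print("skipped:", exc)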