mirror of https://github.com/coqui-ai/TTS.git
Skip TF tests on GPU
This commit is contained in:
parent
1ebf9ec6bf
commit
7ec23e69d4
|
@ -38,6 +38,7 @@ class TacotronTFTrainTest(unittest.TestCase):
|
|||
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
|
||||
return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids
|
||||
|
||||
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
|
||||
def test_train_step(self):
|
||||
"""test forward pass"""
|
||||
(
|
||||
|
@ -70,6 +71,7 @@ class TacotronTFTrainTest(unittest.TestCase):
|
|||
# inference pass
|
||||
output = model(chars_seq, training=False)
|
||||
|
||||
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
|
||||
def test_forward_attention(
|
||||
self,
|
||||
):
|
||||
|
@ -103,6 +105,7 @@ class TacotronTFTrainTest(unittest.TestCase):
|
|||
# inference pass
|
||||
output = model(chars_seq, training=False)
|
||||
|
||||
@unittest.skipIf(use_cuda, " [!] Skip Test: TfLite conversion does not work on GPU.")
|
||||
def test_tflite_conversion(
|
||||
self,
|
||||
): # pylint:disable=no-self-use
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
|
||||
@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
|
||||
def test_melgan_generator():
|
||||
hop_length = 256
|
||||
model = MelganGenerator()
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
import soundfile as sf
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from librosa.core import load
|
||||
|
||||
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||
|
@ -9,8 +11,10 @@ from TTS.vocoder.tf.layers.pqmf import PQMF
|
|||
|
||||
TESTS_PATH = get_tests_path()
|
||||
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
|
||||
@unittest.skipIf(use_cuda, " [!] Skip Test: Loosy TF support.")
|
||||
def test_pqmf():
|
||||
w, sr = load(WAV_FILE)
|
||||
|
||||
|
|
Loading…
Reference in New Issue