mirror of https://github.com/coqui-ai/TTS.git
update tests
This commit is contained in:
parent
ca359727bc
commit
6ccf32c2b9
|
@ -7,7 +7,8 @@ import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import load_config, setup_model
|
from TTS.utils.io import load_config
|
||||||
|
from TTS.utils.generic_utils import setup_model
|
||||||
from TTS.utils.speakers import load_speaker_mapping
|
from TTS.utils.speakers import load_speaker_mapping
|
||||||
# pylint: disable=unused-wildcard-import
|
# pylint: disable=unused-wildcard-import
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
|
|
@ -3,7 +3,7 @@ import unittest
|
||||||
|
|
||||||
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
|
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.io import load_config
|
||||||
|
|
||||||
TESTS_PATH = get_tests_path()
|
TESTS_PATH = get_tests_path()
|
||||||
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
|
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
|
||||||
|
@ -172,4 +172,4 @@ class TestAudio(unittest.TestCase):
|
||||||
mel_reference = self.ap.melspectrogram(wav)
|
mel_reference = self.ap.melspectrogram(wav)
|
||||||
mel_norm = ap.melspectrogram(wav)
|
mel_norm = ap.melspectrogram(wav)
|
||||||
mel_denorm = ap._denormalize(mel_norm)
|
mel_denorm = ap._denormalize(mel_norm)
|
||||||
assert abs(mel_reference - mel_denorm).max() < 1e-4
|
assert abs(mel_reference - mel_denorm).max() < 1e-4
|
||||||
|
|
|
@ -22,7 +22,7 @@ class DemoServerTest(unittest.TestCase):
|
||||||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
||||||
model = setup_model(num_chars, 0, config)
|
model = setup_model(num_chars, 0, config)
|
||||||
output_path = os.path.join(get_tests_output_path())
|
output_path = os.path.join(get_tests_output_path())
|
||||||
save_checkpoint(model, None, None, None, output_path, 10, 10)
|
save_checkpoint(model, None, None, None, 1, output_path)
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
self._create_random_model()
|
self._create_random_model()
|
||||||
|
|
|
@ -60,4 +60,4 @@ class TacotronTFTrainTest(unittest.TestCase):
|
||||||
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
||||||
|
|
||||||
# inference pass
|
# inference pass
|
||||||
output = model(input, training=False)
|
output = model(chars_seq, training=False)
|
||||||
|
|
Loading…
Reference in New Issue