mirror of https://github.com/coqui-ai/TTS.git
Make lint
This commit is contained in:
parent
146fbfd7c9
commit
21940952bf
|
@ -1,7 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -11,6 +10,7 @@ from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.models.vits import Vits, VitsArgs
|
from TTS.tts.models.vits import Vits, VitsArgs
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||||
|
|
||||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||||
|
@ -337,7 +337,7 @@ class TestVits(unittest.TestCase):
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
def _create_batch(self, config, batch_size):
|
def _create_batch(self, config, batch_size):
|
||||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size)
|
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
|
||||||
batch = {}
|
batch = {}
|
||||||
batch["text_input"] = input_dummy
|
batch["text_input"] = input_dummy
|
||||||
batch["text_lengths"] = input_lengths
|
batch["text_lengths"] = input_lengths
|
||||||
|
@ -441,22 +441,26 @@ class TestVits(unittest.TestCase):
|
||||||
self.assertEqual(model.num_speakers, 2)
|
self.assertEqual(model.num_speakers, 2)
|
||||||
self.assertTrue(hasattr(model, "emb_g"))
|
self.assertTrue(hasattr(model, "emb_g"))
|
||||||
|
|
||||||
config = VitsConfig(model_args=VitsArgs(
|
config = VitsConfig(
|
||||||
|
model_args=VitsArgs(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
num_speakers=2,
|
num_speakers=2,
|
||||||
use_speaker_embedding=True,
|
use_speaker_embedding=True,
|
||||||
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
|
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
model = Vits.init_from_config(config, verbose=False).to(device)
|
model = Vits.init_from_config(config, verbose=False).to(device)
|
||||||
self.assertEqual(model.num_speakers, 10)
|
self.assertEqual(model.num_speakers, 10)
|
||||||
self.assertTrue(hasattr(model, "emb_g"))
|
self.assertTrue(hasattr(model, "emb_g"))
|
||||||
|
|
||||||
config = VitsConfig(model_args=VitsArgs(
|
config = VitsConfig(
|
||||||
|
model_args=VitsArgs(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
use_d_vector_file=True,
|
use_d_vector_file=True,
|
||||||
d_vector_dim=256,
|
d_vector_dim=256,
|
||||||
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
|
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
model = Vits.init_from_config(config, verbose=False).to(device)
|
model = Vits.init_from_config(config, verbose=False).to(device)
|
||||||
self.assertTrue(model.num_speakers == 1)
|
self.assertTrue(model.num_speakers == 1)
|
||||||
self.assertTrue(not hasattr(model, "emb_g"))
|
self.assertTrue(not hasattr(model, "emb_g"))
|
||||||
|
|
Loading…
Reference in New Issue