mirror of https://github.com/coqui-ai/TTS.git
Make lint
This commit is contained in:
parent
146fbfd7c9
commit
21940952bf
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
import os
|
||||
import unittest
|
||||
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -11,6 +10,7 @@ from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
|||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.models.vits import Vits, VitsArgs
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||
|
||||
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
|
||||
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
|
||||
|
@ -337,7 +337,7 @@ class TestVits(unittest.TestCase):
|
|||
count += 1
|
||||
|
||||
def _create_batch(self, config, batch_size):
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
|
||||
batch = {}
|
||||
batch["text_input"] = input_dummy
|
||||
batch["text_lengths"] = input_lengths
|
||||
|
@ -441,22 +441,26 @@ class TestVits(unittest.TestCase):
|
|||
self.assertEqual(model.num_speakers, 2)
|
||||
self.assertTrue(hasattr(model, "emb_g"))
|
||||
|
||||
config = VitsConfig(model_args=VitsArgs(
|
||||
config = VitsConfig(
|
||||
model_args=VitsArgs(
|
||||
num_chars=32,
|
||||
num_speakers=2,
|
||||
use_speaker_embedding=True,
|
||||
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
|
||||
))
|
||||
)
|
||||
)
|
||||
model = Vits.init_from_config(config, verbose=False).to(device)
|
||||
self.assertEqual(model.num_speakers, 10)
|
||||
self.assertTrue(hasattr(model, "emb_g"))
|
||||
|
||||
config = VitsConfig(model_args=VitsArgs(
|
||||
config = VitsConfig(
|
||||
model_args=VitsArgs(
|
||||
num_chars=32,
|
||||
use_d_vector_file=True,
|
||||
d_vector_dim=256,
|
||||
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
|
||||
))
|
||||
)
|
||||
)
|
||||
model = Vits.init_from_config(config, verbose=False).to(device)
|
||||
self.assertTrue(model.num_speakers == 1)
|
||||
self.assertTrue(not hasattr(model, "emb_g"))
|
||||
|
|
Loading…
Reference in New Issue