Make lint

This commit is contained in:
Eren Gölge 2022-01-13 17:43:05 +00:00
parent 146fbfd7c9
commit 21940952bf
1 changed files with 18 additions and 14 deletions

View File

@ -1,7 +1,6 @@
import copy
import os
import unittest
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
import torch
@ -11,6 +10,7 @@ from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@ -337,7 +337,7 @@ class TestVits(unittest.TestCase):
count += 1
def _create_batch(self, config, batch_size):
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size)
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
batch = {}
batch["text_input"] = input_dummy
batch["text_lengths"] = input_lengths
@ -441,22 +441,26 @@ class TestVits(unittest.TestCase):
self.assertEqual(model.num_speakers, 2)
self.assertTrue(hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs(
config = VitsConfig(
model_args=VitsArgs(
num_chars=32,
num_speakers=2,
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
))
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
self.assertEqual(model.num_speakers, 10)
self.assertTrue(hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs(
config = VitsConfig(
model_args=VitsArgs(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
))
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))