Make lint

This commit is contained in:
Eren Gölge 2022-01-13 17:43:05 +00:00
parent 146fbfd7c9
commit 21940952bf
1 changed files with 18 additions and 14 deletions

View File

@ -1,7 +1,6 @@
import copy import copy
import os import os
import unittest import unittest
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
import torch import torch
@ -11,6 +10,7 @@ from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@ -337,7 +337,7 @@ class TestVits(unittest.TestCase):
count += 1 count += 1
def _create_batch(self, config, batch_size): def _create_batch(self, config, batch_size):
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size) input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
batch = {} batch = {}
batch["text_input"] = input_dummy batch["text_input"] = input_dummy
batch["text_lengths"] = input_lengths batch["text_lengths"] = input_lengths
@ -441,22 +441,26 @@ class TestVits(unittest.TestCase):
self.assertEqual(model.num_speakers, 2) self.assertEqual(model.num_speakers, 2)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs( config = VitsConfig(
model_args=VitsArgs(
num_chars=32, num_chars=32,
num_speakers=2, num_speakers=2,
use_speaker_embedding=True, use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)) )
)
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config, verbose=False).to(device)
self.assertEqual(model.num_speakers, 10) self.assertEqual(model.num_speakers, 10)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs( config = VitsConfig(
model_args=VitsArgs(
num_chars=32, num_chars=32,
use_d_vector_file=True, use_d_vector_file=True,
d_vector_dim=256, d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)) )
)
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1) self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))