Make lint

This commit is contained in:
Eren Gölge 2022-01-12 14:30:53 +00:00
parent d0eb3e4ef2
commit 2fe16de8e3
7 changed files with 45 additions and 30 deletions

View File

@ -1,14 +1,9 @@
import os
import torch
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
def main():

View File

@ -292,7 +292,6 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
"""https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
file_ext = "flac"
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:

View File

@ -123,7 +123,9 @@ class GlowTTS(BaseTTS):
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
if self.speaker_manager is not None:
assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager."
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")

View File

@ -1,5 +1,4 @@
import math
import random
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union
@ -269,10 +268,20 @@ class Vits(BaseTTS):
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
Examples:
Init only model layers.
>>> from TTS.tts.configs.vits_config import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits(config)
Fully init a model ready for action. All the class attributes and class members
(e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
>>> from TTS.tts.configs.vits_config import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits.init_from_config(config)
"""
# pylint: disable=dangerous-default-value
@ -908,13 +917,10 @@ class Vits(BaseTTS):
aux_inputs["text"],
self.config,
"cuda" in str(next(self.parameters()).device),
self.ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"],
language_name=aux_inputs["language_name"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()

View File

@ -148,7 +148,7 @@ class TTSTokenizer:
# init cleaners
text_cleaner = None
if isinstance(config.text_cleaner, (str, list)):
text_cleaner = getattr(config, "text_cleaner")
text_cleaner = getattr(cleaners, config.text_cleaner)
# init characters
if characters is None:

View File

@ -122,13 +122,9 @@ class Synthesizer(object):
speaker_manager = self._init_speaker_encoder(speaker_manager)
if language_manager is not None:
self.tts_model = setup_tts_model(
config=self.tts_config,
speaker_manager=speaker_manager,
language_manager=language_manager,
)
self.tts_model = setup_tts_model(config=self.tts_config)
else:
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
self.tts_model = setup_tts_model(config=self.tts_config)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
@ -333,7 +329,6 @@ class Synthesizer(object):
use_cuda=self.use_cuda,
speaker_id=speaker_id,
language_id=language_id,
language_name=language_name,
style_wav=style_wav,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,

View File

@ -1,8 +1,6 @@
import copy
import os
import unittest
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
import torch
from torch import optim
@ -11,7 +9,9 @@ from tests import get_tests_data_path, get_tests_input_path, get_tests_output_pa
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
# pylint: disable=unused-variable
@ -31,7 +31,8 @@ def count_parameters(model):
class TestGlowTTS(unittest.TestCase):
def _create_inputs(self):
@staticmethod
def _create_inputs():
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128
@ -40,7 +41,8 @@ class TestGlowTTS(unittest.TestCase):
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
def _check_parameter_changes(self, model, model_ref):
@staticmethod
def _check_parameter_changes(model, model_ref):
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
@ -166,7 +168,7 @@ class TestGlowTTS(unittest.TestCase):
def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
output_shape = outputs["model_outputs"].shape
self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2])
self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
self.assertEqual(outputs["logdet"], None)
self.assertEqual(outputs["y_mean"].shape, output_shape)
self.assertEqual(outputs["y_log_scale"].shape, output_shape)
@ -185,7 +187,12 @@ class TestGlowTTS(unittest.TestCase):
def test_inference_with_d_vector(self):
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
d_vector = torch.rand(8, 256).to(device)
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
config = GlowTTSConfig(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model.eval()
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
@ -268,7 +275,9 @@ class TestGlowTTS(unittest.TestCase):
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model.run_data_dep_init = False
model.train()
logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log")
logger = TensorboardLogger(
log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log"
)
criterion = model.get_criterion()
outputs, _ = model.train_step(batch, criterion)
model.train_log(batch, outputs, logger, None, 1)
@ -316,14 +325,23 @@ class TestGlowTTS(unittest.TestCase):
self.assertTrue(model.num_speakers == 2)
self.assertTrue(hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"))
config = GlowTTSConfig(
num_chars=32,
num_speakers=2,
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 10)
self.assertTrue(hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
config = GlowTTSConfig(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.c_in_channels == config.d_vector_dim)