Make lint

This commit is contained in:
Eren Gölge 2022-01-12 14:30:53 +00:00
parent d0eb3e4ef2
commit 2fe16de8e3
7 changed files with 45 additions and 30 deletions

View File

@ -1,14 +1,9 @@
import os import os
import torch from TTS.config import load_config, register_config
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
def main(): def main():

View File

@ -292,7 +292,6 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
"""https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
file_ext = "flac" file_ext = "flac"
test_speakers = meta_files
items = [] items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files: for meta_file in meta_files:

View File

@ -123,7 +123,9 @@ class GlowTTS(BaseTTS):
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
) )
if self.speaker_manager is not None: if self.speaker_manager is not None:
assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager." assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.") print(" > Init speaker_embedding layer.")

View File

@ -1,5 +1,4 @@
import math import math
import random
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from itertools import chain from itertools import chain
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -269,10 +268,20 @@ class Vits(BaseTTS):
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
Examples: Examples:
Init only model layers.
>>> from TTS.tts.configs.vits_config import VitsConfig >>> from TTS.tts.configs.vits_config import VitsConfig
>>> from TTS.tts.models.vits import Vits >>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig() >>> config = VitsConfig()
>>> model = Vits(config) >>> model = Vits(config)
Fully init a model ready for action. All the class attributes and class members
(e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values.
>>> from TTS.tts.configs.vits_config import VitsConfig
>>> from TTS.tts.models.vits import Vits
>>> config = VitsConfig()
>>> model = Vits.init_from_config(config)
""" """
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
@ -908,13 +917,10 @@ class Vits(BaseTTS):
aux_inputs["text"], aux_inputs["text"],
self.config, self.config,
"cuda" in str(next(self.parameters()).device), "cuda" in str(next(self.parameters()).device),
self.ap,
speaker_id=aux_inputs["speaker_id"], speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"], style_wav=aux_inputs["style_wav"],
language_id=aux_inputs["language_id"], language_id=aux_inputs["language_id"],
language_name=aux_inputs["language_name"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False, do_trim_silence=False,
).values() ).values()

View File

@ -148,7 +148,7 @@ class TTSTokenizer:
# init cleaners # init cleaners
text_cleaner = None text_cleaner = None
if isinstance(config.text_cleaner, (str, list)): if isinstance(config.text_cleaner, (str, list)):
text_cleaner = getattr(config, "text_cleaner") text_cleaner = getattr(cleaners, config.text_cleaner)
# init characters # init characters
if characters is None: if characters is None:

View File

@ -122,13 +122,9 @@ class Synthesizer(object):
speaker_manager = self._init_speaker_encoder(speaker_manager) speaker_manager = self._init_speaker_encoder(speaker_manager)
if language_manager is not None: if language_manager is not None:
self.tts_model = setup_tts_model( self.tts_model = setup_tts_model(config=self.tts_config)
config=self.tts_config,
speaker_manager=speaker_manager,
language_manager=language_manager,
)
else: else:
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) self.tts_model = setup_tts_model(config=self.tts_config)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda: if use_cuda:
self.tts_model.cuda() self.tts_model.cuda()
@ -333,7 +329,6 @@ class Synthesizer(object):
use_cuda=self.use_cuda, use_cuda=self.use_cuda,
speaker_id=speaker_id, speaker_id=speaker_id,
language_id=language_id, language_id=language_id,
language_name=language_name,
style_wav=style_wav, style_wav=style_wav,
use_griffin_lim=use_gl, use_griffin_lim=use_gl,
d_vector=speaker_embedding, d_vector=speaker_embedding,

View File

@ -1,8 +1,6 @@
import copy import copy
import os import os
import unittest import unittest
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
import torch import torch
from torch import optim from torch import optim
@ -11,7 +9,9 @@ from tests import get_tests_data_path, get_tests_input_path, get_tests_output_pa
from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
# pylint: disable=unused-variable # pylint: disable=unused-variable
@ -31,7 +31,8 @@ def count_parameters(model):
class TestGlowTTS(unittest.TestCase): class TestGlowTTS(unittest.TestCase):
def _create_inputs(self): @staticmethod
def _create_inputs():
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device)
input_lengths[-1] = 128 input_lengths[-1] = 128
@ -40,7 +41,8 @@ class TestGlowTTS(unittest.TestCase):
speaker_ids = torch.randint(0, 5, (8,)).long().to(device) speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
def _check_parameter_changes(self, model, model_ref): @staticmethod
def _check_parameter_changes(model, model_ref):
count = 0 count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()): for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
@ -166,7 +168,7 @@ class TestGlowTTS(unittest.TestCase):
def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
output_shape = outputs["model_outputs"].shape output_shape = outputs["model_outputs"].shape
self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2]) self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
self.assertEqual(outputs["logdet"], None) self.assertEqual(outputs["logdet"], None)
self.assertEqual(outputs["y_mean"].shape, output_shape) self.assertEqual(outputs["y_mean"].shape, output_shape)
self.assertEqual(outputs["y_log_scale"].shape, output_shape) self.assertEqual(outputs["y_log_scale"].shape, output_shape)
@ -185,7 +187,12 @@ class TestGlowTTS(unittest.TestCase):
def test_inference_with_d_vector(self): def test_inference_with_d_vector(self):
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
d_vector = torch.rand(8, 256).to(device) d_vector = torch.rand(8, 256).to(device)
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) config = GlowTTSConfig(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config, verbose=False).to(device)
model.eval() model.eval()
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
@ -268,7 +275,9 @@ class TestGlowTTS(unittest.TestCase):
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config, verbose=False).to(device)
model.run_data_dep_init = False model.run_data_dep_init = False
model.train() model.train()
logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log") logger = TensorboardLogger(
log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log"
)
criterion = model.get_criterion() criterion = model.get_criterion()
outputs, _ = model.train_step(batch, criterion) outputs, _ = model.train_step(batch, criterion)
model.train_log(batch, outputs, logger, None, 1) model.train_log(batch, outputs, logger, None, 1)
@ -316,14 +325,23 @@ class TestGlowTTS(unittest.TestCase):
self.assertTrue(model.num_speakers == 2) self.assertTrue(model.num_speakers == 2)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")) config = GlowTTSConfig(
num_chars=32,
num_speakers=2,
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 10) self.assertTrue(model.num_speakers == 10)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) config = GlowTTSConfig(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1) self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.c_in_channels == config.d_vector_dim) self.assertTrue(model.c_in_channels == config.d_vector_dim)