mirror of https://github.com/coqui-ai/TTS.git
Make lint
This commit is contained in:
parent
d0eb3e4ef2
commit
2fe16de8e3
|
@ -1,14 +1,9 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
from TTS.config import load_config, register_config
|
||||||
|
|
||||||
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
|
||||||
from TTS.trainer import Trainer, TrainingArgs
|
from TTS.trainer import Trainer, TrainingArgs
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
@ -292,7 +292,6 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
|
||||||
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
|
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
|
||||||
"""https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
|
"""https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
|
||||||
file_ext = "flac"
|
file_ext = "flac"
|
||||||
test_speakers = meta_files
|
|
||||||
items = []
|
items = []
|
||||||
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
|
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
|
||||||
for meta_file in meta_files:
|
for meta_file in meta_files:
|
||||||
|
|
|
@ -123,7 +123,9 @@ class GlowTTS(BaseTTS):
|
||||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||||
)
|
)
|
||||||
if self.speaker_manager is not None:
|
if self.speaker_manager is not None:
|
||||||
assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager."
|
assert (
|
||||||
|
config.d_vector_dim == self.speaker_manager.d_vector_dim
|
||||||
|
), " [!] d-vector dimension mismatch b/w config and speaker manager."
|
||||||
# init speaker embedding layer
|
# init speaker embedding layer
|
||||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||||
print(" > Init speaker_embedding layer.")
|
print(" > Init speaker_embedding layer.")
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import math
|
import math
|
||||||
import random
|
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
@ -269,10 +268,20 @@ class Vits(BaseTTS):
|
||||||
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
|
Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
Init only model layers.
|
||||||
|
|
||||||
>>> from TTS.tts.configs.vits_config import VitsConfig
|
>>> from TTS.tts.configs.vits_config import VitsConfig
|
||||||
>>> from TTS.tts.models.vits import Vits
|
>>> from TTS.tts.models.vits import Vits
|
||||||
>>> config = VitsConfig()
|
>>> config = VitsConfig()
|
||||||
>>> model = Vits(config)
|
>>> model = Vits(config)
|
||||||
|
|
||||||
|
Fully init a model ready for action. All the class attributes and class members
|
||||||
|
(e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.
|
||||||
|
|
||||||
|
>>> from TTS.tts.configs.vits_config import VitsConfig
|
||||||
|
>>> from TTS.tts.models.vits import Vits
|
||||||
|
>>> config = VitsConfig()
|
||||||
|
>>> model = Vits.init_from_config(config)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
@ -908,13 +917,10 @@ class Vits(BaseTTS):
|
||||||
aux_inputs["text"],
|
aux_inputs["text"],
|
||||||
self.config,
|
self.config,
|
||||||
"cuda" in str(next(self.parameters()).device),
|
"cuda" in str(next(self.parameters()).device),
|
||||||
self.ap,
|
|
||||||
speaker_id=aux_inputs["speaker_id"],
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
d_vector=aux_inputs["d_vector"],
|
d_vector=aux_inputs["d_vector"],
|
||||||
style_wav=aux_inputs["style_wav"],
|
style_wav=aux_inputs["style_wav"],
|
||||||
language_id=aux_inputs["language_id"],
|
language_id=aux_inputs["language_id"],
|
||||||
language_name=aux_inputs["language_name"],
|
|
||||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
|
||||||
use_griffin_lim=True,
|
use_griffin_lim=True,
|
||||||
do_trim_silence=False,
|
do_trim_silence=False,
|
||||||
).values()
|
).values()
|
||||||
|
|
|
@ -148,7 +148,7 @@ class TTSTokenizer:
|
||||||
# init cleaners
|
# init cleaners
|
||||||
text_cleaner = None
|
text_cleaner = None
|
||||||
if isinstance(config.text_cleaner, (str, list)):
|
if isinstance(config.text_cleaner, (str, list)):
|
||||||
text_cleaner = getattr(config, "text_cleaner")
|
text_cleaner = getattr(cleaners, config.text_cleaner)
|
||||||
|
|
||||||
# init characters
|
# init characters
|
||||||
if characters is None:
|
if characters is None:
|
||||||
|
|
|
@ -122,13 +122,9 @@ class Synthesizer(object):
|
||||||
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||||
|
|
||||||
if language_manager is not None:
|
if language_manager is not None:
|
||||||
self.tts_model = setup_tts_model(
|
self.tts_model = setup_tts_model(config=self.tts_config)
|
||||||
config=self.tts_config,
|
|
||||||
speaker_manager=speaker_manager,
|
|
||||||
language_manager=language_manager,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
|
self.tts_model = setup_tts_model(config=self.tts_config)
|
||||||
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
|
@ -333,7 +329,6 @@ class Synthesizer(object):
|
||||||
use_cuda=self.use_cuda,
|
use_cuda=self.use_cuda,
|
||||||
speaker_id=speaker_id,
|
speaker_id=speaker_id,
|
||||||
language_id=language_id,
|
language_id=language_id,
|
||||||
language_name=language_name,
|
|
||||||
style_wav=style_wav,
|
style_wav=style_wav,
|
||||||
use_griffin_lim=use_gl,
|
use_griffin_lim=use_gl,
|
||||||
d_vector=speaker_embedding,
|
d_vector=speaker_embedding,
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
|
||||||
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
@ -11,7 +9,9 @@ from tests import get_tests_data_path, get_tests_input_path, get_tests_output_pa
|
||||||
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.models.glow_tts import GlowTTS
|
from TTS.tts.models.glow_tts import GlowTTS
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||||
|
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
|
@ -31,7 +31,8 @@ def count_parameters(model):
|
||||||
|
|
||||||
|
|
||||||
class TestGlowTTS(unittest.TestCase):
|
class TestGlowTTS(unittest.TestCase):
|
||||||
def _create_inputs(self):
|
@staticmethod
|
||||||
|
def _create_inputs():
|
||||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
input_lengths[-1] = 128
|
input_lengths[-1] = 128
|
||||||
|
@ -40,7 +41,8 @@ class TestGlowTTS(unittest.TestCase):
|
||||||
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||||
return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
|
return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
|
||||||
|
|
||||||
def _check_parameter_changes(self, model, model_ref):
|
@staticmethod
|
||||||
|
def _check_parameter_changes(model, model_ref):
|
||||||
count = 0
|
count = 0
|
||||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
||||||
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
|
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
|
||||||
|
@ -166,7 +168,7 @@ class TestGlowTTS(unittest.TestCase):
|
||||||
|
|
||||||
def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
|
def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
|
||||||
output_shape = outputs["model_outputs"].shape
|
output_shape = outputs["model_outputs"].shape
|
||||||
self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2])
|
self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
|
||||||
self.assertEqual(outputs["logdet"], None)
|
self.assertEqual(outputs["logdet"], None)
|
||||||
self.assertEqual(outputs["y_mean"].shape, output_shape)
|
self.assertEqual(outputs["y_mean"].shape, output_shape)
|
||||||
self.assertEqual(outputs["y_log_scale"].shape, output_shape)
|
self.assertEqual(outputs["y_log_scale"].shape, output_shape)
|
||||||
|
@ -185,7 +187,12 @@ class TestGlowTTS(unittest.TestCase):
|
||||||
def test_inference_with_d_vector(self):
|
def test_inference_with_d_vector(self):
|
||||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
|
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
|
||||||
d_vector = torch.rand(8, 256).to(device)
|
d_vector = torch.rand(8, 256).to(device)
|
||||||
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
|
config = GlowTTSConfig(
|
||||||
|
num_chars=32,
|
||||||
|
use_d_vector_file=True,
|
||||||
|
d_vector_dim=256,
|
||||||
|
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
|
||||||
|
)
|
||||||
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
|
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
|
||||||
|
@ -268,7 +275,9 @@ class TestGlowTTS(unittest.TestCase):
|
||||||
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
||||||
model.run_data_dep_init = False
|
model.run_data_dep_init = False
|
||||||
model.train()
|
model.train()
|
||||||
logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log")
|
logger = TensorboardLogger(
|
||||||
|
log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log"
|
||||||
|
)
|
||||||
criterion = model.get_criterion()
|
criterion = model.get_criterion()
|
||||||
outputs, _ = model.train_step(batch, criterion)
|
outputs, _ = model.train_step(batch, criterion)
|
||||||
model.train_log(batch, outputs, logger, None, 1)
|
model.train_log(batch, outputs, logger, None, 1)
|
||||||
|
@ -316,14 +325,23 @@ class TestGlowTTS(unittest.TestCase):
|
||||||
self.assertTrue(model.num_speakers == 2)
|
self.assertTrue(model.num_speakers == 2)
|
||||||
self.assertTrue(hasattr(model, "emb_g"))
|
self.assertTrue(hasattr(model, "emb_g"))
|
||||||
|
|
||||||
config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"))
|
config = GlowTTSConfig(
|
||||||
|
num_chars=32,
|
||||||
|
num_speakers=2,
|
||||||
|
use_speaker_embedding=True,
|
||||||
|
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
|
||||||
|
)
|
||||||
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
||||||
self.assertTrue(model.num_speakers == 10)
|
self.assertTrue(model.num_speakers == 10)
|
||||||
self.assertTrue(hasattr(model, "emb_g"))
|
self.assertTrue(hasattr(model, "emb_g"))
|
||||||
|
|
||||||
config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
|
config = GlowTTSConfig(
|
||||||
|
num_chars=32,
|
||||||
|
use_d_vector_file=True,
|
||||||
|
d_vector_dim=256,
|
||||||
|
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
|
||||||
|
)
|
||||||
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
model = GlowTTS.init_from_config(config, verbose=False).to(device)
|
||||||
self.assertTrue(model.num_speakers == 1)
|
self.assertTrue(model.num_speakers == 1)
|
||||||
self.assertTrue(not hasattr(model, "emb_g"))
|
self.assertTrue(not hasattr(model, "emb_g"))
|
||||||
self.assertTrue(model.c_in_channels == config.d_vector_dim)
|
self.assertTrue(model.c_in_channels == config.d_vector_dim)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue