Update imports for symbols -> characters

This commit is contained in:
Eren Gölge 2021-11-17 12:46:04 +01:00
parent 9a95e15483
commit 2d8ce98d2a
2 changed files with 6 additions and 12 deletions

View File

@ -14,7 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import Graphemes, make_symbols
from TTS.tts.utils.text.characters import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
@ -33,7 +33,9 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
def __init__(
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
):
super().__init__(config)
self.config = config
self.ap = ap
@ -71,9 +73,6 @@ class BaseTTS(BaseModel):
else:
raise ValueError("config must be either a *Config or *Args")
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@ -291,7 +290,7 @@ class BaseTTS(BaseModel):
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer
tokenizer=self.tokenizer,
)
# pre-compute phonemes

View File

@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
# AND... 3,2,1... 🚀