Update imports for symbols -> characters

This commit is contained in:
Eren Gölge 2021-11-17 12:46:04 +01:00
parent 9a95e15483
commit 2d8ce98d2a
2 changed files with 6 additions and 12 deletions

View File

@ -14,7 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import Graphemes, make_symbols from TTS.tts.utils.text.characters import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file # pylint: skip-file
@ -33,7 +33,9 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1` - 1D tensors `batch x 1`
""" """
def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): def __init__(
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.ap = ap self.ap = ap
@ -71,9 +73,6 @@ class BaseTTS(BaseModel):
else: else:
raise ValueError("config must be either a *Config or *Args") raise ValueError("config must be either a *Config or *Args")
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit, data: List = None): def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers. `in_channels` size of the connected layers.
@ -291,7 +290,7 @@ class BaseTTS(BaseModel):
verbose=verbose, verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer tokenizer=self.tokenizer,
) )
# pre-compute phonemes # pre-compute phonemes

View File

@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc. # distributed training, etc.
trainer = Trainer( trainer = Trainer(
TrainingArgs(), TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples
) )
# AND... 3,2,1... 🚀 # AND... 3,2,1... 🚀