mirror of https://github.com/coqui-ai/TTS.git
Refactor GlowTTS model and recipe for TTSTokenizer
This commit is contained in:
parent
5a9653978a
commit
bd461ace33
|
@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC):
|
||||||
|
|
||||||
def __init__(self, config: Coqpit):
|
def __init__(self, config: Coqpit):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._set_model_args(config)
|
|
||||||
|
|
||||||
def _set_model_args(self, config: Coqpit):
|
@staticmethod
|
||||||
"""Set model arguments from the config. Override this."""
|
def init_from_config(config: Coqpit):
|
||||||
|
"""Init the model from given config.
|
||||||
|
|
||||||
|
Override this depending on your model.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text import make_symbols
|
from TTS.tts.utils.text.symbols import Graphemes, make_symbols
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
|
||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
@ -34,8 +34,20 @@ class BaseTTS(BaseModel):
|
||||||
- 1D tensors `batch x 1`
|
- 1D tensors `batch x 1`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.ap = ap
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.speaker_manager = speaker_manager
|
||||||
|
self._set_model_args(config)
|
||||||
|
|
||||||
def _set_model_args(self, config: Coqpit):
|
def _set_model_args(self, config: Coqpit):
|
||||||
"""Setup model args based on the config type.
|
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
|
||||||
|
|
||||||
|
`ModelArgs` has all the fields reuqired to initialize the model architecture.
|
||||||
|
|
||||||
|
`ModelConfig` has all the fields required for training, inference and containes `ModelArgs`.
|
||||||
|
|
||||||
If the config is for training with a name like "*Config", then the model args are embeded in the
|
If the config is for training with a name like "*Config", then the model args are embeded in the
|
||||||
config.model_args
|
config.model_args
|
||||||
|
@ -44,8 +56,8 @@ class BaseTTS(BaseModel):
|
||||||
"""
|
"""
|
||||||
# don't use isintance not to import recursively
|
# don't use isintance not to import recursively
|
||||||
if "Config" in config.__class__.__name__:
|
if "Config" in config.__class__.__name__:
|
||||||
|
num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
|
||||||
if "characters" in config:
|
if "characters" in config:
|
||||||
_, self.config, num_chars = self.get_characters(config)
|
|
||||||
self.config.num_chars = num_chars
|
self.config.num_chars = num_chars
|
||||||
if hasattr(self.config, "model_args"):
|
if hasattr(self.config, "model_args"):
|
||||||
config.model_args.num_chars = num_chars
|
config.model_args.num_chars = num_chars
|
||||||
|
@ -58,18 +70,21 @@ class BaseTTS(BaseModel):
|
||||||
else:
|
else:
|
||||||
raise ValueError("config must be either a *Config or *Args")
|
raise ValueError("config must be either a *Config or *Args")
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def get_characters(config: Coqpit) -> str:
|
# def get_characters(config: Coqpit) -> str:
|
||||||
# TODO: implement CharacterProcessor
|
# # TODO: implement CharacterProcessor
|
||||||
if config.characters is not None:
|
# if config.characters is not None:
|
||||||
symbols, phonemes = make_symbols(**config.characters)
|
# symbols, phonemes = make_symbols(**config.characters)
|
||||||
else:
|
# else:
|
||||||
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
|
# from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
|
||||||
|
|
||||||
config.characters = CharactersConfig(**parse_symbols())
|
# if config.use_phonemes:
|
||||||
model_characters = phonemes if config.use_phonemes else symbols
|
|
||||||
num_chars = len(model_characters) + getattr(config, "add_blank", False)
|
# config.characters = Graphemes()
|
||||||
return model_characters, config, num_chars
|
|
||||||
|
# model_characters = phonemes if config.use_phonemes else symbols
|
||||||
|
# num_chars = len(model_characters) + getattr(config, "add_blank", False)
|
||||||
|
# return model_characters, config, num_chars
|
||||||
|
|
||||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||||
return get_speaker_manager(config, restore_path, data, out_path)
|
return get_speaker_manager(config, restore_path, data, out_path)
|
||||||
|
@ -247,8 +262,6 @@ class BaseTTS(BaseModel):
|
||||||
if is_eval and not config.run_eval:
|
if is_eval and not config.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
ap = assets["audio_processor"]
|
|
||||||
|
|
||||||
# setup multi-speaker attributes
|
# setup multi-speaker attributes
|
||||||
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
|
@ -279,28 +292,21 @@ class BaseTTS(BaseModel):
|
||||||
# init dataloader
|
# init dataloader
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
outputs_per_step=config.r if "r" in config else 1,
|
outputs_per_step=config.r if "r" in config else 1,
|
||||||
text_cleaner=config.text_cleaner,
|
|
||||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||||
compute_f0=config.get("compute_f0", False),
|
compute_f0=config.get("compute_f0", False),
|
||||||
f0_cache_path=config.get("f0_cache_path", None),
|
f0_cache_path=config.get("f0_cache_path", None),
|
||||||
meta_data=data_items,
|
meta_data=data_items,
|
||||||
ap=ap,
|
ap=self.ap,
|
||||||
characters=config.characters,
|
|
||||||
custom_symbols=custom_symbols,
|
|
||||||
add_blank=config["add_blank"],
|
|
||||||
return_wav=config.return_wav if "return_wav" in config else False,
|
return_wav=config.return_wav if "return_wav" in config else False,
|
||||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||||
min_seq_len=config.min_seq_len,
|
min_seq_len=config.min_seq_len,
|
||||||
max_seq_len=config.max_seq_len,
|
max_seq_len=config.max_seq_len,
|
||||||
phoneme_cache_path=config.phoneme_cache_path,
|
phoneme_cache_path=config.phoneme_cache_path,
|
||||||
use_phonemes=config.use_phonemes,
|
|
||||||
phoneme_language=config.phoneme_language,
|
|
||||||
enable_eos_bos=config.enable_eos_bos_chars,
|
|
||||||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_id_mapping,
|
speaker_id_mapping=speaker_id_mapping,
|
||||||
d_vector_mapping=d_vector_mapping,
|
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||||
language_id_mapping=language_id_mapping,
|
tokenizer=self.tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
# pre-compute phonemes
|
# pre-compute phonemes
|
||||||
|
@ -332,7 +338,7 @@ class BaseTTS(BaseModel):
|
||||||
if config.compute_f0 and rank in [None, 0]:
|
if config.compute_f0 and rank in [None, 0]:
|
||||||
if not os.path.exists(config.f0_cache_path):
|
if not os.path.exists(config.f0_cache_path):
|
||||||
dataset.pitch_extractor.compute_pitch(
|
dataset.pitch_extractor.compute_pitch(
|
||||||
ap, config.get("f0_cache_path", None), config.num_loader_workers
|
self.ap, config.get("f0_cache_path", None), config.num_loader_workers
|
||||||
)
|
)
|
||||||
|
|
||||||
# halt DDP processes for the main process to finish computing the F0 cache
|
# halt DDP processes for the main process to finish computing the F0 cache
|
||||||
|
@ -404,6 +410,7 @@ class BaseTTS(BaseModel):
|
||||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||||
"""
|
"""
|
||||||
ap = assets["audio_processor"]
|
ap = assets["audio_processor"]
|
||||||
|
tokenizer = assets["tokenizer"]
|
||||||
print(" | > Synthesizing test sentences.")
|
print(" | > Synthesizing test sentences.")
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
|
@ -416,6 +423,7 @@ class BaseTTS(BaseModel):
|
||||||
self.config,
|
self.config,
|
||||||
"cuda" in str(next(self.parameters()).device),
|
"cuda" in str(next(self.parameters()).device),
|
||||||
ap,
|
ap,
|
||||||
|
tokenizer,
|
||||||
speaker_id=aux_inputs["speaker_id"],
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
d_vector=aux_inputs["d_vector"],
|
d_vector=aux_inputs["d_vector"],
|
||||||
style_wav=aux_inputs["style_wav"],
|
style_wav=aux_inputs["style_wav"],
|
||||||
|
|
|
@ -46,11 +46,9 @@ class GlowTTS(BaseTTS):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
|
def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
|
||||||
|
|
||||||
super().__init__(config)
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
||||||
self.speaker_manager = speaker_manager
|
|
||||||
|
|
||||||
# pass all config fields to `self`
|
# pass all config fields to `self`
|
||||||
# for fewer code change
|
# for fewer code change
|
||||||
|
@ -58,7 +56,7 @@ class GlowTTS(BaseTTS):
|
||||||
for key in config:
|
for key in config:
|
||||||
setattr(self, key, config[key])
|
setattr(self, key, config[key])
|
||||||
|
|
||||||
_, self.config, self.num_chars = self.get_characters(config)
|
self.num_chars = self.tokenizer.characters.num_chars
|
||||||
self.decoder_output_dim = config.out_channels
|
self.decoder_output_dim = config.out_channels
|
||||||
|
|
||||||
# init multi-speaker layers if necessary
|
# init multi-speaker layers if necessary
|
||||||
|
@ -448,7 +446,6 @@ class GlowTTS(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||||
"""
|
"""
|
||||||
ap = assets["audio_processor"]
|
|
||||||
print(" | > Synthesizing test sentences.")
|
print(" | > Synthesizing test sentences.")
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
|
@ -463,7 +460,8 @@ class GlowTTS(BaseTTS):
|
||||||
sen,
|
sen,
|
||||||
self.config,
|
self.config,
|
||||||
"cuda" in str(next(self.parameters()).device),
|
"cuda" in str(next(self.parameters()).device),
|
||||||
ap,
|
self.ap,
|
||||||
|
self.tokenizer,
|
||||||
speaker_id=aux_inputs["speaker_id"],
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
d_vector=aux_inputs["d_vector"],
|
d_vector=aux_inputs["d_vector"],
|
||||||
style_wav=aux_inputs["style_wav"],
|
style_wav=aux_inputs["style_wav"],
|
||||||
|
|
|
@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models.glow_tts import GlowTTS
|
from TTS.tts.models.glow_tts import GlowTTS
|
||||||
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
# we use the same path as this script as our training folder.
|
# we use the same path as this script as our training folder.
|
||||||
|
@ -47,7 +48,11 @@ config = GlowTTSConfig(
|
||||||
# INITIALIZE THE AUDIO PROCESSOR
|
# INITIALIZE THE AUDIO PROCESSOR
|
||||||
# Audio processor is used for feature extraction and audio I/O.
|
# Audio processor is used for feature extraction and audio I/O.
|
||||||
# It mainly serves to the dataloader and the training loggers.
|
# It mainly serves to the dataloader and the training loggers.
|
||||||
ap = AudioProcessor(**config.audio.to_dict())
|
ap = AudioProcessor.init_from_config(config)
|
||||||
|
|
||||||
|
# INITIALIZE THE TOKENIZER
|
||||||
|
# Tokenizer is used to convert text to sequences of token IDs.
|
||||||
|
tokenizer = TTSTokenizer.init_from_config(config)
|
||||||
|
|
||||||
# LOAD DATA SAMPLES
|
# LOAD DATA SAMPLES
|
||||||
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
|
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
|
||||||
|
@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
||||||
# Models take a config object and a speaker manager as input
|
# Models take a config object and a speaker manager as input
|
||||||
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
|
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
|
||||||
# Speaker manager is used by multi-speaker models.
|
# Speaker manager is used by multi-speaker models.
|
||||||
model = GlowTTS(config, speaker_manager=None)
|
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
|
||||||
|
|
||||||
# INITIALIZE THE TRAINER
|
# INITIALIZE THE TRAINER
|
||||||
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
|
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
|
||||||
|
@ -71,8 +76,7 @@ trainer = Trainer(
|
||||||
output_path,
|
output_path,
|
||||||
model=model,
|
model=model,
|
||||||
train_samples=train_samples,
|
train_samples=train_samples,
|
||||||
eval_samples=eval_samples,
|
eval_samples=eval_samples
|
||||||
training_assets={"audio_processor": ap}, # assets are objetcs used by the models but not class members.
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# AND... 3,2,1... 🚀
|
# AND... 3,2,1... 🚀
|
||||||
|
|
Loading…
Reference in New Issue