mirror of https://github.com/coqui-ai/TTS.git
Move multilingual logic out of the trainer
This commit is contained in:
parent
818dc4ccd8
commit
6b03943526
|
@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs
|
|||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
|
@ -60,8 +61,17 @@ def main():
|
|||
else:
|
||||
speaker_manager = None
|
||||
|
||||
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
|
||||
language_manager = LanguageManager(config=config)
|
||||
if hasattr(config, "model_args"):
|
||||
config.model_args.num_languages = language_manager.num_languages
|
||||
else:
|
||||
config.num_languages = language_manager.num_languages
|
||||
else:
|
||||
language_manager = None
|
||||
|
||||
# init the model from config
|
||||
model = setup_model(config, speaker_manager)
|
||||
model = setup_model(config, speaker_manager, language_manager)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
|
|
|
@ -260,22 +260,6 @@ class Trainer:
|
|||
else:
|
||||
self.run_get_model(self.config, get_model)
|
||||
|
||||
if hasattr(self.model, "init_multilingual"):
|
||||
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
|
||||
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
|
||||
# save speakers json
|
||||
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
|
||||
self.model.language_manager.save_language_ids_to_file(
|
||||
os.path.join(self.output_path, "language_ids.json")
|
||||
)
|
||||
if hasattr(self.config, "model_args"):
|
||||
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
|
||||
else:
|
||||
self.config.num_languages = self.model.language_manager.num_languages
|
||||
|
||||
# update config file
|
||||
copy_model_files(self.config, self.output_path)
|
||||
|
||||
# setup criterion
|
||||
self.criterion = self.get_criterion(self.model)
|
||||
|
||||
|
|
|
@ -85,6 +85,12 @@ class VitsConfig(BaseTTSConfig):
|
|||
test_sentences (List[List]):
|
||||
List of sentences with speaker and language information to be used for testing.
|
||||
|
||||
language_ids_file (str):
|
||||
Path to the language ids file.
|
||||
|
||||
use_language_embedding (bool):
|
||||
If true, language embedding is used. Defaults to `False`.
|
||||
|
||||
Note:
|
||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
||||
|
||||
|
@ -147,6 +153,8 @@ class VitsConfig(BaseTTSConfig):
|
|||
use_speaker_embedding: bool = False
|
||||
speakers_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
language_ids_file: str = None
|
||||
use_language_embedding: bool = False
|
||||
|
||||
# use d-vectors
|
||||
use_d_vector_file: bool = False
|
||||
|
|
|
@ -2,7 +2,11 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
|
|||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
def setup_model(config, speaker_manager: "SpeakerManager" = None):
|
||||
def setup_model(
|
||||
config,
|
||||
speaker_manager: "SpeakerManager" = None,
|
||||
language_manager: "LanguageManager" = None
|
||||
):
|
||||
print(" > Using model: {}".format(config.model))
|
||||
# fetch the right model implementation.
|
||||
if "base_model" in config and config["base_model"] is not None:
|
||||
|
@ -31,7 +35,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
|
|||
config.model_params.num_chars = num_chars
|
||||
if "model_args" in config:
|
||||
config.model_args.num_chars = num_chars
|
||||
model = MyModel(config, speaker_manager=speaker_manager)
|
||||
if config.model.lower() in ["vits"]: # If model supports multiple languages
|
||||
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
|
||||
else:
|
||||
model = MyModel(config, speaker_manager=speaker_manager)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -419,8 +419,7 @@ class BaseTTS(BaseModel):
|
|||
return test_figures, test_audios
|
||||
|
||||
def on_init_start(self, trainer):
|
||||
"""Save the speaker.json at the beginning of the training. And update the config.json with the
|
||||
speakers.json file path."""
|
||||
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
|
||||
if self.speaker_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "speakers.json")
|
||||
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
||||
|
@ -431,3 +430,13 @@ class BaseTTS(BaseModel):
|
|||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `speakers.json` is saved to {output_path}.")
|
||||
print(" > `speakers_file` is updated in the config.json.")
|
||||
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||
self.language_manager.save_language_ids_to_file(output_path)
|
||||
trainer.config.language_ids_file = output_path
|
||||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.language_ids_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||
print(" > `language_ids_file` is updated in the config.json.")
|
||||
|
|
|
@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
|||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
|
@ -158,6 +158,9 @@ class VitsArgs(Coqpit):
|
|||
num_languages (int):
|
||||
Number of languages for the language embedding layer. Defaults to 0.
|
||||
|
||||
language_ids_file (str):
|
||||
Path to the language mapping file for the Language Manager. Defaults to None.
|
||||
|
||||
use_speaker_encoder_as_loss (bool):
|
||||
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
|
||||
|
||||
|
@ -225,6 +228,7 @@ class VitsArgs(Coqpit):
|
|||
use_language_embedding: bool = False
|
||||
embedded_language_dim: int = 4
|
||||
num_languages: int = 0
|
||||
language_ids_file: str = None
|
||||
use_speaker_encoder_as_loss: bool = False
|
||||
speaker_encoder_config_path: str = ""
|
||||
speaker_encoder_model_path: str = ""
|
||||
|
@ -265,13 +269,18 @@ class Vits(BaseTTS):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Coqpit,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
language_manager: LanguageManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
self.END2END = True
|
||||
self.speaker_manager = speaker_manager
|
||||
self.audio_config = config["audio"]
|
||||
self.language_manager = language_manager
|
||||
if config.__class__.__name__ == "VitsConfig":
|
||||
# loading from VitsConfig
|
||||
if "num_chars" not in config:
|
||||
|
@ -443,43 +452,20 @@ class Vits(BaseTTS):
|
|||
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
|
||||
if config.use_speaker_encoder_as_loss:
|
||||
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
||||
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
|
||||
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
|
||||
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
|
||||
for param in self.speaker_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
print(" > External Speaker Encoder Loaded !!")
|
||||
|
||||
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
|
||||
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
|
||||
else:
|
||||
self.audio_transform = None
|
||||
else:
|
||||
self.audio_transform = None
|
||||
self.speaker_encoder = None
|
||||
|
||||
def init_multilingual(self, config: Coqpit, data: List = None):
|
||||
def init_multilingual(self, config: Coqpit):
|
||||
"""Initialize multilingual modules of a model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
if hasattr(config, "model_args"):
|
||||
config = config.model_args
|
||||
# init language manager
|
||||
self.language_manager = LanguageManager(config, data=data)
|
||||
|
||||
# init language embedding layer
|
||||
if config.use_language_embedding:
|
||||
if config.num_languages > 0 and self.language_manager.num_languages == 0:
|
||||
self.num_languages = config.num_languages
|
||||
else:
|
||||
self.num_languages = self.language_manager.num_languages
|
||||
if config.language_ids_file is not None:
|
||||
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||
|
||||
if config.use_language_embedding and self.language_manager:
|
||||
self.num_languages = self.language_manager.num_languages
|
||||
self.embedded_language_dim = config.embedded_language_dim
|
||||
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
||||
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
@ -14,11 +14,13 @@ class LanguageManager:
|
|||
in a way that can be queried by language.
|
||||
|
||||
Args:
|
||||
language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
|
||||
language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
|
||||
TTS models. Defaults to "".
|
||||
config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
|
||||
>>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
|
||||
>>> language_id_mapper = manager.language_ids
|
||||
"""
|
||||
|
||||
|
@ -26,10 +28,14 @@ class LanguageManager:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
language_id_file_path: str = "",
|
||||
language_ids_file_path: str = "",
|
||||
config: Coqpit = None,
|
||||
):
|
||||
if language_id_file_path:
|
||||
self.set_language_ids_from_file(language_id_file_path)
|
||||
if language_ids_file_path:
|
||||
self.set_language_ids_from_file(language_ids_file_path)
|
||||
|
||||
if config:
|
||||
self.set_language_ids_from_config(config)
|
||||
|
||||
@staticmethod
|
||||
def _load_json(json_file_path: str) -> Dict:
|
||||
|
@ -50,27 +56,30 @@ class LanguageManager:
|
|||
return list(self.language_id_mapping.keys())
|
||||
|
||||
@staticmethod
|
||||
def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
|
||||
"""Parse language IDs from data samples retured by `load_meta_data()`.
|
||||
def parse_language_ids_from_config(c: Coqpit) -> Dict:
|
||||
"""Set language id from config.
|
||||
|
||||
Args:
|
||||
items (list): Data sampled returned by `load_meta_data()`.
|
||||
c (Coqpit): Config
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, int]: language IDs and number of languages.
|
||||
Tuple[Dict, int]: Language ID mapping and the number of languages.
|
||||
"""
|
||||
languages = sorted({item[3] for item in items})
|
||||
language_ids = {name: i for i, name in enumerate(languages)}
|
||||
num_languages = len(language_ids)
|
||||
return language_ids, num_languages
|
||||
languages = set({})
|
||||
for dataset in c.datasets:
|
||||
if "language" in dataset:
|
||||
languages.add(dataset["language"])
|
||||
else:
|
||||
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
|
||||
return {name: i for i, name in enumerate(sorted(list(languages)))}
|
||||
|
||||
def set_language_ids_from_data(self, items: List) -> None:
|
||||
"""Set language IDs from data samples.
|
||||
def set_language_ids_from_config(self, c: Coqpit) -> None:
|
||||
"""Set language IDs from config samples.
|
||||
|
||||
Args:
|
||||
items (List): Data sampled returned by `load_meta_data()`.
|
||||
"""
|
||||
self.language_id_mapping, _ = self.parse_languages_from_data(items)
|
||||
self.language_id_mapping = self.parse_language_ids_from_config(c)
|
||||
|
||||
def set_language_ids_from_file(self, file_path: str) -> None:
|
||||
"""Load language ids from a json file.
|
||||
|
@ -102,36 +111,6 @@ def _set_file_path(path):
|
|||
return None
|
||||
|
||||
|
||||
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
|
||||
"""Initiate a `LanguageManager` instance by the provided config.
|
||||
|
||||
Args:
|
||||
c (Coqpit): Model configuration.
|
||||
restore_path (str): Path to a previous training folder.
|
||||
data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
|
||||
out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
|
||||
|
||||
Returns:
|
||||
SpeakerManager: initialized and ready to use instance.
|
||||
"""
|
||||
language_manager = LanguageManager()
|
||||
if c.use_language_embedding:
|
||||
if data is not None:
|
||||
language_manager.set_language_ids_from_data(data)
|
||||
if restore_path:
|
||||
language_file = _set_file_path(restore_path)
|
||||
# restoring language manager from a previous run.
|
||||
if language_file:
|
||||
language_manager.set_language_ids_from_file(language_file)
|
||||
if language_manager.num_languages > 0:
|
||||
print(
|
||||
" > Language manager is loaded with {} languages: {}".format(
|
||||
language_manager.num_languages, ", ".join(language_manager.language_names)
|
||||
)
|
||||
)
|
||||
return language_manager
|
||||
|
||||
|
||||
def get_language_weighted_sampler(items: list):
|
||||
language_names = np.array([item[3] for item in items])
|
||||
unique_language_names = np.unique(language_names).tolist()
|
||||
|
|
Loading…
Reference in New Issue