diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index a543a947..5330649a 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.languages import LanguageManager
 from TTS.utils.audio import AudioProcessor


@@ -60,8 +61,17 @@ def main():
     else:
         speaker_manager = None

+    if hasattr(config, "use_language_embedding") and config.use_language_embedding:
+        language_manager = LanguageManager(config=config)
+        if hasattr(config, "model_args"):
+            config.model_args.num_languages = language_manager.num_languages
+        else:
+            config.num_languages = language_manager.num_languages
+    else:
+        language_manager = None
+
     # init the model from config
-    model = setup_model(config, speaker_manager)
+    model = setup_model(config, speaker_manager, language_manager)

     # init the trainer and 🚀
     trainer = Trainer(
diff --git a/TTS/trainer.py b/TTS/trainer.py
index b9026c8e..7bffb386 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -260,22 +260,6 @@ class Trainer:
         else:
             self.run_get_model(self.config, get_model)

-        if hasattr(self.model, "init_multilingual"):
-            self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
-            config = self.config.model_args if hasattr(self.config, "model_args") else self.config
-            # save speakers json
-            if config.use_language_embedding and self.model.language_manager.num_languages > 1:
-                self.model.language_manager.save_language_ids_to_file(
-                    os.path.join(self.output_path, "language_ids.json")
-                )
-            if hasattr(self.config, "model_args"):
-                self.config.model_args["num_languages"] = self.model.language_manager.num_languages
-            else:
-                self.config.num_languages = self.model.language_manager.num_languages
-
-        # update config file
-        copy_model_files(self.config, self.output_path)
-
         # setup criterion
         self.criterion = self.get_criterion(self.model)

diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index 178992a7..32a69bca 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -85,6 +85,12 @@ class VitsConfig(BaseTTSConfig):
         test_sentences (List[List]):
             List of sentences with speaker and language information to be used for testing.

+        language_ids_file (str):
+            Path to the language ids file.
+
+        use_language_embedding (bool):
+            If true, language embedding is used. Defaults to `False`.
+
     Note:
         Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

@@ -147,6 +153,8 @@ class VitsConfig(BaseTTSConfig):
     use_speaker_embedding: bool = False
     speakers_file: str = None
     speaker_embedding_channels: int = 256
+    language_ids_file: str = None
+    use_language_embedding: bool = False

     # use d-vectors
     use_d_vector_file: bool = False
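The new block in TTS/bin/train_tts.py above only needs two things from the config: the `use_language_embedding` flag added to `VitsConfig` and a `language` entry on every dataset (that is what `LanguageManager.parse_language_ids_from_config` iterates over further down in this diff). A minimal sketch of that wiring, with made-up dataset names, paths, and language codes:

    # Sketch only: the dataset entries are illustrative plain dicts standing in
    # for the real dataset configs; only their "language" keys matter here.
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.utils.languages import LanguageManager

    config = VitsConfig(use_language_embedding=True)
    config.datasets = [
        {"name": "vctk", "path": "/data/VCTK", "language": "en"},
        {"name": "css10_fr", "path": "/data/CSS10/fr", "language": "fr"},
    ]

    language_manager = LanguageManager(config=config)  # builds {"en": 0, "fr": 1}
    config.model_args.num_languages = language_manager.num_languages  # 2
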
if "base_model" in config and config["base_model"] is not None: @@ -31,7 +35,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None): config.model_params.num_chars = num_chars if "model_args" in config: config.model_args.num_chars = num_chars - model = MyModel(config, speaker_manager=speaker_manager) + if config.model.lower() in ["vits"]: # If model supports multiple languages + model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) + else: + model = MyModel(config, speaker_manager=speaker_manager) return model diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 707fc9c3..14bc9180 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -419,8 +419,7 @@ class BaseTTS(BaseModel): return test_figures, test_audios def on_init_start(self, trainer): - """Save the speaker.json at the beginning of the training. And update the config.json with the - speakers.json file path.""" + """Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths.""" if self.speaker_manager is not None: output_path = os.path.join(trainer.output_path, "speakers.json") self.speaker_manager.save_speaker_ids_to_file(output_path) @@ -431,3 +430,13 @@ class BaseTTS(BaseModel): trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) print(f" > `speakers.json` is saved to {output_path}.") print(" > `speakers_file` is updated in the config.json.") + + if hasattr(self, "language_manager") and self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_language_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index c1c29980..ca110eb0 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment from TTS.utils.trainer_utils import get_optimizer, get_scheduler @@ -158,6 +158,9 @@ class VitsArgs(Coqpit): num_languages (int): Number of languages for the language embedding layer. Defaults to 0. + language_ids_file (str): + Path to the language mapping file for the Language Manager. Defaults to None. + use_speaker_encoder_as_loss (bool): Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index c1c29980..ca110eb0 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -158,6 +158,9 @@ class VitsArgs(Coqpit):
         num_languages (int):
             Number of languages for the language embedding layer. Defaults to 0.

+        language_ids_file (str):
+            Path to the language mapping file for the Language Manager. Defaults to None.
+
         use_speaker_encoder_as_loss (bool):
             Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.

@@ -225,6 +228,7 @@ class VitsArgs(Coqpit):
     use_language_embedding: bool = False
     embedded_language_dim: int = 4
     num_languages: int = 0
+    language_ids_file: str = None
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
@@ -265,13 +269,18 @@ class Vits(BaseTTS):
     # pylint: disable=dangerous-default-value

-    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
+    def __init__(
+        self,
+        config: Coqpit,
+        speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None,
+    ):

         super().__init__(config)

         self.END2END = True
         self.speaker_manager = speaker_manager
-        self.audio_config = config["audio"]
+        self.language_manager = language_manager

         if config.__class__.__name__ == "VitsConfig":
             # loading from VitsConfig
             if "num_chars" not in config:
@@ -443,43 +452,20 @@ class Vits(BaseTTS):
                 self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim

-        if config.use_speaker_encoder_as_loss:
-            if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
-                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
-            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
-            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
-            for param in self.speaker_encoder.parameters():
-                param.requires_grad = False
-
-            print(" > External Speaker Encoder Loaded !!")
-
-            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
-                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
-            else:
-                self.audio_transform = None
-        else:
-            self.audio_transform = None
-            self.speaker_encoder = None
-
-    def init_multilingual(self, config: Coqpit, data: List = None):
+    def init_multilingual(self, config: Coqpit):
         """Initialize multilingual modules of a model.

         Args:
             config (Coqpit): Model configuration.
-            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
         """
         if hasattr(config, "model_args"):
             config = config.model_args

-        # init language manager
-        self.language_manager = LanguageManager(config, data=data)

-        # init language embedding layer
-        if config.use_language_embedding:
-            if config.num_languages > 0 and self.language_manager.num_languages == 0:
-                self.num_languages = config.num_languages
-            else:
-                self.num_languages = self.language_manager.num_languages
+        if config.language_ids_file is not None:
+            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
+        if config.use_language_embedding and self.language_manager:
+            self.num_languages = self.language_manager.num_languages
             self.embedded_language_dim = config.embedded_language_dim
             self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
             torch.nn.init.xavier_uniform_(self.emb_l.weight)
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index 5bacc259..451b10f9 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Dict, List, Tuple
+from typing import Dict, List

 import fsspec
 import numpy as np
@@ -14,11 +14,13 @@ class LanguageManager:
     in a way that can be queried by language.

     Args:
-        language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
+        language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
        TTS models. Defaults to "".
+        config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed.
+            Defaults to None.

     Examples:
-        >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
+        >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
         >>> language_id_mapper = manager.language_ids
     """

@@ -26,10 +28,14 @@ class LanguageManager:

     def __init__(
         self,
-        language_id_file_path: str = "",
+        language_ids_file_path: str = "",
+        config: Coqpit = None,
     ):
-        if language_id_file_path:
-            self.set_language_ids_from_file(language_id_file_path)
+        if language_ids_file_path:
+            self.set_language_ids_from_file(language_ids_file_path)
+
+        if config:
+            self.set_language_ids_from_config(config)

     @staticmethod
     def _load_json(json_file_path: str) -> Dict:
@@ -50,27 +56,30 @@ class LanguageManager:
         return list(self.language_id_mapping.keys())

     @staticmethod
-    def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
-        """Parse language IDs from data samples retured by `load_meta_data()`.
+    def parse_language_ids_from_config(c: Coqpit) -> Dict:
+        """Set language id from config.

         Args:
-            items (list): Data sampled returned by `load_meta_data()`.
+            c (Coqpit): Config

         Returns:
-            Tuple[Dict, int]: language IDs and number of languages.
+            Tuple[Dict, int]: Language ID mapping and the number of languages.
         """
-        languages = sorted({item[3] for item in items})
-        language_ids = {name: i for i, name in enumerate(languages)}
-        num_languages = len(language_ids)
-        return language_ids, num_languages
+        languages = set({})
+        for dataset in c.datasets:
+            if "language" in dataset:
+                languages.add(dataset["language"])
+            else:
+                raise ValueError(f"Dataset {dataset['name']} has no language specified.")
+        return {name: i for i, name in enumerate(sorted(list(languages)))}

-    def set_language_ids_from_data(self, items: List) -> None:
-        """Set language IDs from data samples.
+    def set_language_ids_from_config(self, c: Coqpit) -> None:
+        """Set language IDs from config samples.

         Args:
             items (List): Data sampled returned by `load_meta_data()`.
         """
-        self.language_id_mapping, _ = self.parse_languages_from_data(items)
+        self.language_id_mapping = self.parse_language_ids_from_config(c)

     def set_language_ids_from_file(self, file_path: str) -> None:
         """Load language ids from a json file.
@@ -102,36 +111,6 @@ def _set_file_path(path):
     return None


-def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
-    """Initiate a `LanguageManager` instance by the provided config.
-
-    Args:
-        c (Coqpit): Model configuration.
-        restore_path (str): Path to a previous training folder.
-        data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
-        out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
-
-    Returns:
-        SpeakerManager: initialized and ready to use instance.
-    """
-    language_manager = LanguageManager()
-    if c.use_language_embedding:
-        if data is not None:
-            language_manager.set_language_ids_from_data(data)
-        if restore_path:
-            language_file = _set_file_path(restore_path)
-            # restoring language manager from a previous run.
-            if language_file:
-                language_manager.set_language_ids_from_file(language_file)
-    if language_manager.num_languages > 0:
-        print(
-            " > Language manager is loaded with {} languages: {}".format(
-                language_manager.num_languages, ", ".join(language_manager.language_names)
-            )
-        )
-    return language_manager
-
-
 def get_language_weighted_sampler(items: list):
     language_names = np.array([item[3] for item in items])
     unique_language_names = np.unique(language_names).tolist()
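
Finally, a hedged usage sketch for the sampling helper kept at the bottom of languages.py. The samples below are invented; the only structural assumption, visible in the list comprehensions above, is that the language name sits at `item[3]` of each sample, and the sketch assumes `get_language_weighted_sampler` returns a torch-compatible sampler (its body is truncated in this hunk):

    from torch.utils.data import DataLoader

    from TTS.tts.utils.languages import get_language_weighted_sampler

    # Made-up samples shaped like [text, wav_path, speaker_name, language].
    train_samples = [
        ["bonjour", "wavs/fr_0001.wav", "speaker_a", "fr"],
        ["hello", "wavs/en_0001.wav", "speaker_b", "en"],
        ["hello again", "wavs/en_0002.wav", "speaker_b", "en"],
    ]

    sampler = get_language_weighted_sampler(train_samples)  # balances "fr" against "en"
    loader = DataLoader(train_samples, batch_size=2, sampler=sampler)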