diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 84406b40..2e369e77 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -280,7 +280,7 @@ If you don't specify any models, then it uses LJSpeech based English model. print( " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - print(synthesizer.tts_model.language_manager.language_id_mapping) + print(synthesizer.tts_model.language_manager.ids) return # check the arguments against a multi-speaker model. diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index df8e221d..5e4094e0 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -141,13 +141,13 @@ class BaseTTS(BaseTrainerModel): d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) elif config.use_speaker_embedding: if speaker_name is None: - speaker_id = self.speaker_manager.get_random_speaker_id() + speaker_id = self.speaker_manager.get_random_id() else: speaker_id = self.speaker_manager.ids[speaker_name] # get language id if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: - language_id = self.language_manager.language_id_mapping[language_name] + language_id = self.language_manager.ids[language_name] return { "text": text, @@ -294,7 +294,7 @@ class BaseTTS(BaseTrainerModel): # setup multi-lingual attributes if hasattr(self, "language_manager") and self.language_manager is not None: language_id_mapping = ( - self.language_manager.language_id_mapping if self.args.use_language_embedding else None + self.language_manager.ids if self.args.use_language_embedding else None ) else: language_id_mapping = None @@ -416,7 +416,7 @@ class BaseTTS(BaseTrainerModel): if hasattr(self, "language_manager") and self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") - self.language_manager.save_language_ids_to_file(output_path) + self.language_manager.save_ids_to_file(output_path) trainer.config.language_ids_file = output_path if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 793559ac..156c20d8 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1220,13 +1220,13 @@ class Vits(BaseTTS): d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) elif config.use_speaker_embedding: if speaker_name is None: - speaker_id = self.speaker_manager.get_random_speaker_id() + speaker_id = self.speaker_manager.get_random_id() else: speaker_id = self.speaker_manager.ids[speaker_name] # get language id if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: - language_id = self.language_manager.language_id_mapping[language_name] + language_id = self.language_manager.ids[language_name] return { "text": text, @@ -1297,10 +1297,10 @@ class Vits(BaseTTS): # get language ids from language names if ( self.language_manager is not None - and self.language_manager.language_id_mapping + and self.language_manager.ids and self.args.use_language_embedding ): - language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]] if language_ids is not None: language_ids = torch.LongTensor(language_ids) diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 7decabb0..ff1df767 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,6 +1,5 @@ -import json import os -from typing import Dict, List +from typing import Dict, List, Any import fsspec import numpy as np @@ -8,9 +7,9 @@ import torch from coqpit import Coqpit from TTS.config import check_config_and_model_args +from TTS.tts.utils.managers import BaseIDManager - -class LanguageManager: +class LanguageManager(BaseIDManager): """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information in a way that can be queried by language. @@ -25,37 +24,23 @@ class LanguageManager: >>> language_id_mapper = manager.language_ids """ - language_id_mapping: Dict = {} - def __init__( self, language_ids_file_path: str = "", config: Coqpit = None, ): - self.language_id_mapping = {} - if language_ids_file_path: - self.set_language_ids_from_file(language_ids_file_path) + super().__init__(id_file_path=language_ids_file_path) if config: self.set_language_ids_from_config(config) - @staticmethod - def _load_json(json_file_path: str) -> Dict: - with fsspec.open(json_file_path, "r") as f: - return json.load(f) - - @staticmethod - def _save_json(json_file_path: str, data: dict) -> None: - with fsspec.open(json_file_path, "w") as f: - json.dump(data, f, indent=4) - @property def num_languages(self) -> int: - return len(list(self.language_id_mapping.keys())) + return len(list(self.ids.keys())) @property def language_names(self) -> List: - return list(self.language_id_mapping.keys()) + return list(self.ids.keys()) @staticmethod def parse_language_ids_from_config(c: Coqpit) -> Dict: @@ -81,23 +66,22 @@ class LanguageManager: Args: items (List): Data sampled returned by `load_meta_data()`. """ - self.language_id_mapping = self.parse_language_ids_from_config(c) + self.ids = self.parse_language_ids_from_config(c) - def set_language_ids_from_file(self, file_path: str) -> None: - """Load language ids from a json file. + @staticmethod + def parse_ids_from_data(items: list) -> Any: + raise NotImplementedError - Args: - file_path (str): Path to the target json file. - """ - self.language_id_mapping = self._load_json(file_path) + def set_ids_from_data(self, items: List) -> Any: + raise NotImplementedError - def save_language_ids_to_file(self, file_path: str) -> None: + def save_ids_to_file(self, file_path: str) -> None: """Save language IDs to a json file. Args: file_path (str): Path to the output file. """ - self._save_json(file_path, self.language_id_mapping) + self._save_json(file_path, self.ids) @staticmethod def init_from_config(config: Coqpit) -> "LanguageManager": diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 7242a31d..480c4c98 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -44,7 +44,7 @@ class BaseIDManager: self.ids, _ = self.parse_ids_from_data(items) def set_ids_from_file(self, file_path: str) -> None: - """Set speaker IDs from a file. + """Set IDs from a file. Args: file_path (str): Path to the file. @@ -52,14 +52,14 @@ class BaseIDManager: self.ids = self._load_json(file_path) def save_ids_to_file(self, file_path: str) -> None: - """Save speaker IDs to a json file. + """Save IDs to a json file. Args: file_path (str): Path to the output file. """ self._save_json(file_path, self.ids) - def get_random_speaker_id(self) -> Any: + def get_random_id(self) -> Any: """Get a random embedding. Args: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 4813fd41..5e2ecc72 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -240,7 +240,7 @@ class Synthesizer(object): hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None ): if language_name and isinstance(language_name, str): - language_id = self.tts_model.language_manager.language_id_mapping[language_name] + language_id = self.tts_model.language_manager.ids[language_name] elif not language_name: raise ValueError(