mirror of https://github.com/coqui-ai/TTS.git
Implement LanguageManager inherit BaseIDManager
This commit is contained in:
parent
eac06a5e87
commit
e33819b7de
|
@ -280,7 +280,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
print(
|
||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||
)
|
||||
print(synthesizer.tts_model.language_manager.language_id_mapping)
|
||||
print(synthesizer.tts_model.language_manager.ids)
|
||||
return
|
||||
|
||||
# check the arguments against a multi-speaker model.
|
||||
|
|
|
@ -141,13 +141,13 @@ class BaseTTS(BaseTrainerModel):
|
|||
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_speaker_id()
|
||||
speaker_id = self.speaker_manager.get_random_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.ids[speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.language_id_mapping[language_name]
|
||||
language_id = self.language_manager.ids[language_name]
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
|
@ -294,7 +294,7 @@ class BaseTTS(BaseTrainerModel):
|
|||
# setup multi-lingual attributes
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
language_id_mapping = (
|
||||
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
||||
self.language_manager.ids if self.args.use_language_embedding else None
|
||||
)
|
||||
else:
|
||||
language_id_mapping = None
|
||||
|
@ -416,7 +416,7 @@ class BaseTTS(BaseTrainerModel):
|
|||
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||
self.language_manager.save_language_ids_to_file(output_path)
|
||||
self.language_manager.save_ids_to_file(output_path)
|
||||
trainer.config.language_ids_file = output_path
|
||||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.language_ids_file = output_path
|
||||
|
|
|
@ -1220,13 +1220,13 @@ class Vits(BaseTTS):
|
|||
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_speaker_id()
|
||||
speaker_id = self.speaker_manager.get_random_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.ids[speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.language_id_mapping[language_name]
|
||||
language_id = self.language_manager.ids[language_name]
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
|
@ -1297,10 +1297,10 @@ class Vits(BaseTTS):
|
|||
# get language ids from language names
|
||||
if (
|
||||
self.language_manager is not None
|
||||
and self.language_manager.language_id_mapping
|
||||
and self.language_manager.ids
|
||||
and self.args.use_language_embedding
|
||||
):
|
||||
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
|
||||
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
|
||||
|
||||
if language_ids is not None:
|
||||
language_ids = torch.LongTensor(language_ids)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Any
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
@ -8,9 +7,9 @@ import torch
|
|||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import check_config_and_model_args
|
||||
from TTS.tts.utils.managers import BaseIDManager
|
||||
|
||||
|
||||
class LanguageManager:
|
||||
class LanguageManager(BaseIDManager):
|
||||
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
|
||||
in a way that can be queried by language.
|
||||
|
||||
|
@ -25,37 +24,23 @@ class LanguageManager:
|
|||
>>> language_id_mapper = manager.language_ids
|
||||
"""
|
||||
|
||||
language_id_mapping: Dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
language_ids_file_path: str = "",
|
||||
config: Coqpit = None,
|
||||
):
|
||||
self.language_id_mapping = {}
|
||||
if language_ids_file_path:
|
||||
self.set_language_ids_from_file(language_ids_file_path)
|
||||
super().__init__(id_file_path=language_ids_file_path)
|
||||
|
||||
if config:
|
||||
self.set_language_ids_from_config(config)
|
||||
|
||||
@staticmethod
|
||||
def _load_json(json_file_path: str) -> Dict:
|
||||
with fsspec.open(json_file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def _save_json(json_file_path: str, data: dict) -> None:
|
||||
with fsspec.open(json_file_path, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
@property
|
||||
def num_languages(self) -> int:
|
||||
return len(list(self.language_id_mapping.keys()))
|
||||
return len(list(self.ids.keys()))
|
||||
|
||||
@property
|
||||
def language_names(self) -> List:
|
||||
return list(self.language_id_mapping.keys())
|
||||
return list(self.ids.keys())
|
||||
|
||||
@staticmethod
|
||||
def parse_language_ids_from_config(c: Coqpit) -> Dict:
|
||||
|
@ -81,23 +66,22 @@ class LanguageManager:
|
|||
Args:
|
||||
items (List): Data sampled returned by `load_meta_data()`.
|
||||
"""
|
||||
self.language_id_mapping = self.parse_language_ids_from_config(c)
|
||||
self.ids = self.parse_language_ids_from_config(c)
|
||||
|
||||
def set_language_ids_from_file(self, file_path: str) -> None:
|
||||
"""Load language ids from a json file.
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: list) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the target json file.
|
||||
"""
|
||||
self.language_id_mapping = self._load_json(file_path)
|
||||
def set_ids_from_data(self, items: List) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_language_ids_to_file(self, file_path: str) -> None:
|
||||
def save_ids_to_file(self, file_path: str) -> None:
|
||||
"""Save language IDs to a json file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the output file.
|
||||
"""
|
||||
self._save_json(file_path, self.language_id_mapping)
|
||||
self._save_json(file_path, self.ids)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit) -> "LanguageManager":
|
||||
|
|
|
@ -44,7 +44,7 @@ class BaseIDManager:
|
|||
self.ids, _ = self.parse_ids_from_data(items)
|
||||
|
||||
def set_ids_from_file(self, file_path: str) -> None:
|
||||
"""Set speaker IDs from a file.
|
||||
"""Set IDs from a file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the file.
|
||||
|
@ -52,14 +52,14 @@ class BaseIDManager:
|
|||
self.ids = self._load_json(file_path)
|
||||
|
||||
def save_ids_to_file(self, file_path: str) -> None:
|
||||
"""Save speaker IDs to a json file.
|
||||
"""Save IDs to a json file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the output file.
|
||||
"""
|
||||
self._save_json(file_path, self.ids)
|
||||
|
||||
def get_random_speaker_id(self) -> Any:
|
||||
def get_random_id(self) -> Any:
|
||||
"""Get a random embedding.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -240,7 +240,7 @@ class Synthesizer(object):
|
|||
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
|
||||
):
|
||||
if language_name and isinstance(language_name, str):
|
||||
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
|
||||
language_id = self.tts_model.language_manager.ids[language_name]
|
||||
|
||||
elif not language_name:
|
||||
raise ValueError(
|
||||
|
|
Loading…
Reference in New Issue