Make LanguageManager inherit from BaseIDManager

This commit is contained in:
Edresson Casanova 2022-03-11 19:25:18 -03:00
parent 4fdc864f74
commit c7af7c6474
6 changed files with 27 additions and 43 deletions

View File

@ -286,7 +286,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
print(synthesizer.tts_model.language_manager.ids)
return
# check the arguments against a multi-speaker model.

View File

@ -141,13 +141,13 @@ class BaseTTS(BaseTrainerModel):
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@ -294,7 +294,7 @@ class BaseTTS(BaseTrainerModel):
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
self.language_manager.ids if self.args.use_language_embedding else None
)
else:
language_id_mapping = None
@ -416,7 +416,7 @@ class BaseTTS(BaseTrainerModel):
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
self.language_manager.save_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path

View File

@ -1229,13 +1229,13 @@ class Vits(BaseTTS):
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]
return {
"text": text,
@ -1306,10 +1306,10 @@ class Vits(BaseTTS):
# get language ids from language names
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.language_manager.ids
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)

View File

@ -1,6 +1,5 @@
import json
import os
from typing import Dict, List
from typing import Dict, List, Any
import fsspec
import numpy as np
@ -8,9 +7,9 @@ import torch
from coqpit import Coqpit
from TTS.config import check_config_and_model_args
from TTS.tts.utils.managers import BaseIDManager
class LanguageManager:
class LanguageManager(BaseIDManager):
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
@ -25,37 +24,23 @@ class LanguageManager:
>>> language_id_mapper = manager.language_ids
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_ids_file_path: str = "",
config: Coqpit = None,
):
self.language_id_mapping = {}
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)
super().__init__(id_file_path=language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_languages(self) -> int:
return len(list(self.language_id_mapping.keys()))
return len(list(self.ids.keys()))
@property
def language_names(self) -> List:
return list(self.language_id_mapping.keys())
return list(self.ids.keys())
@staticmethod
def parse_language_ids_from_config(c: Coqpit) -> Dict:
@ -81,23 +66,22 @@ class LanguageManager:
Args:
items (List): Data sampled returned by `load_meta_data()`.
"""
self.language_id_mapping = self.parse_language_ids_from_config(c)
self.ids = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
@staticmethod
def parse_ids_from_data(items: list) -> Any:
raise NotImplementedError
Args:
file_path (str): Path to the target json file.
"""
self.language_id_mapping = self._load_json(file_path)
def set_ids_from_data(self, items: List) -> Any:
raise NotImplementedError
def save_language_ids_to_file(self, file_path: str) -> None:
def save_ids_to_file(self, file_path: str) -> None:
"""Save language IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.language_id_mapping)
self._save_json(file_path, self.ids)
@staticmethod
def init_from_config(config: Coqpit) -> "LanguageManager":

View File

@ -44,7 +44,7 @@ class BaseIDManager:
self.ids, _ = self.parse_ids_from_data(items)
def set_ids_from_file(self, file_path: str) -> None:
"""Set speaker IDs from a file.
"""Set IDs from a file.
Args:
file_path (str): Path to the file.
@ -52,14 +52,14 @@ class BaseIDManager:
self.ids = self._load_json(file_path)
def save_ids_to_file(self, file_path: str) -> None:
"""Save speaker IDs to a json file.
"""Save IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.ids)
def get_random_speaker_id(self) -> Any:
def get_random_id(self) -> Any:
"""Get a random embedding.
Args:

View File

@ -242,7 +242,7 @@ class Synthesizer(object):
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_name and isinstance(language_name, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
language_id = self.tts_model.language_manager.ids[language_name]
elif not language_name:
raise ValueError(