coqui-tts/TTS/tts/utils/languages.py

137 lines
5.0 KiB
Python

import os
import json
import torch
import fsspec
import numpy as np
from typing import Dict, Tuple, List
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
class LanguageManager:
    """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by language.

    Args:
        language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
            TTS models. Defaults to "".

    Examples:
        >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
        >>> language_id_mapper = manager.language_ids
    """

    # Class-level annotation kept for backward compatibility; each instance gets
    # its own dict in `__init__` so in-place mutations are never shared.
    language_id_mapping: Dict = {}

    def __init__(
        self,
        language_id_file_path: str = "",
    ):
        # Fresh per-instance mapping — avoids the shared mutable class default.
        self.language_id_mapping = {}
        if language_id_file_path:
            self.set_language_ids_from_file(language_id_file_path)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        """Load a JSON file (local or remote, via fsspec) and return its content as a dict."""
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        """Dump `data` as indented JSON to the given (possibly remote) path."""
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    @property
    def num_languages(self) -> int:
        """Number of languages currently known to the manager."""
        # len() on the dict directly — no need to materialize the keys.
        return len(self.language_id_mapping)

    @property
    def language_names(self) -> List:
        """List of language names in mapping order."""
        return list(self.language_id_mapping.keys())

    @staticmethod
    def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
        """Parse language IDs from data samples retured by `load_meta_data()`.

        Args:
            items (list): Data sampled returned by `load_meta_data()`.

        Returns:
            Tuple[Dict, int]: language IDs and number of languages.
        """
        # Languages are sorted so the name -> id mapping is deterministic.
        languages = sorted({item[3] for item in items})
        language_ids = {name: i for i, name in enumerate(languages)}
        num_languages = len(language_ids)
        return language_ids, num_languages

    def set_language_ids_from_data(self, items: List) -> None:
        """Set language IDs from data samples.

        Args:
            items (List): Data sampled returned by `load_meta_data()`.
        """
        self.language_id_mapping, _ = self.parse_languages_from_data(items)

    def set_language_ids_from_file(self, file_path: str) -> None:
        """Load language ids from a json file.

        Args:
            file_path (str): Path to the target json file.
        """
        self.language_id_mapping = self._load_json(file_path)

    def save_language_ids_to_file(self, file_path: str) -> None:
        """Save language IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.language_id_mapping)
def _set_file_path(path):
    """Locate `language_ids.json` either next to `path`'s parent or inside `path`.

    Intended to band aid the different paths returned in restored and
    continued training. Returns the first existing candidate, or None.
    """
    # Check the restore layout first, then the continue-training layout.
    candidates = (
        os.path.join(os.path.dirname(path), "language_ids.json"),
        os.path.join(path, "language_ids.json"),
    )
    fs = fsspec.get_mapper(path).fs
    for candidate in candidates:
        if fs.exists(candidate):
            return candidate
    return None
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
    """Initiate a `LanguageManager` instance by the provided config.

    Args:
        c (Coqpit): Model configuration.
        data (List): Data samples returned by `load_meta_data()`. Defaults to None.
        restore_path (str): Path to a previous training folder. Defaults to None.

    Returns:
        LanguageManager: initialized and ready to use instance.
    """
    manager = LanguageManager()
    if not c.use_language_embedding:
        # Language embedding disabled — return an empty manager.
        return manager
    if data is not None:
        manager.set_language_ids_from_data(data)
    if restore_path:
        # Restoring a language manager from a previous run.
        language_file = _set_file_path(restore_path)
        if language_file:
            manager.set_language_ids_from_file(language_file)
    if manager.num_languages > 0:
        print(
            " > Language manager is loaded with {} languages: {}".format(
                manager.num_languages, ", ".join(manager.language_names)
            )
        )
    return manager
def get_language_weighted_sampler(items: list):
    """Create a `WeightedRandomSampler` that balances sampling across languages.

    Each sample is weighted by the inverse frequency of its language, so
    under-represented languages are drawn as often as common ones in expectation.

    Args:
        items (list): Data samples returned by `load_meta_data()`; the language
            name is expected at index 3 of each item.

    Returns:
        WeightedRandomSampler: sampler drawing `len(items)` samples (with replacement).
    """
    language_names = np.array([item[3] for item in items])
    # One pass instead of per-language `list.index` / `np.where` scans:
    # `np.unique` yields sorted uniques, each item's index into them, and counts.
    _, language_ids, language_count = np.unique(language_names, return_inverse=True, return_counts=True)
    weight_language = 1.0 / language_count
    dataset_samples_weight = torch.from_numpy(weight_language[language_ids]).double()
    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))