mirror of https://github.com/coqui-ai/TTS.git
Move multilingual logic out of the trainer
This commit is contained in:
parent
b909a3b63e
commit
352b4be104
TTS
|
@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,8 +61,17 @@ def main():
|
||||||
else:
|
else:
|
||||||
speaker_manager = None
|
speaker_manager = None
|
||||||
|
|
||||||
|
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
|
||||||
|
language_manager = LanguageManager(config=config)
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
config.model_args.num_languages = language_manager.num_languages
|
||||||
|
else:
|
||||||
|
config.num_languages = language_manager.num_languages
|
||||||
|
else:
|
||||||
|
language_manager = None
|
||||||
|
|
||||||
# init the model from config
|
# init the model from config
|
||||||
model = setup_model(config, speaker_manager)
|
model = setup_model(config, speaker_manager, language_manager)
|
||||||
|
|
||||||
# init the trainer and 🚀
|
# init the trainer and 🚀
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|
|
@ -260,22 +260,6 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
self.run_get_model(self.config, get_model)
|
self.run_get_model(self.config, get_model)
|
||||||
|
|
||||||
if hasattr(self.model, "init_multilingual"):
|
|
||||||
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
|
|
||||||
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
|
|
||||||
# save speakers json
|
|
||||||
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
|
|
||||||
self.model.language_manager.save_language_ids_to_file(
|
|
||||||
os.path.join(self.output_path, "language_ids.json")
|
|
||||||
)
|
|
||||||
if hasattr(self.config, "model_args"):
|
|
||||||
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
|
|
||||||
else:
|
|
||||||
self.config.num_languages = self.model.language_manager.num_languages
|
|
||||||
|
|
||||||
# update config file
|
|
||||||
copy_model_files(self.config, self.output_path)
|
|
||||||
|
|
||||||
# setup criterion
|
# setup criterion
|
||||||
self.criterion = self.get_criterion(self.model)
|
self.criterion = self.get_criterion(self.model)
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,12 @@ class VitsConfig(BaseTTSConfig):
|
||||||
test_sentences (List[List]):
|
test_sentences (List[List]):
|
||||||
List of sentences with speaker and language information to be used for testing.
|
List of sentences with speaker and language information to be used for testing.
|
||||||
|
|
||||||
|
language_ids_file (str):
|
||||||
|
Path to the language ids file.
|
||||||
|
|
||||||
|
use_language_embedding (bool):
|
||||||
|
If true, language embedding is used. Defaults to `False`.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
||||||
|
|
||||||
|
@ -147,6 +153,8 @@ class VitsConfig(BaseTTSConfig):
|
||||||
use_speaker_embedding: bool = False
|
use_speaker_embedding: bool = False
|
||||||
speakers_file: str = None
|
speakers_file: str = None
|
||||||
speaker_embedding_channels: int = 256
|
speaker_embedding_channels: int = 256
|
||||||
|
language_ids_file: str = None
|
||||||
|
use_language_embedding: bool = False
|
||||||
|
|
||||||
# use d-vectors
|
# use d-vectors
|
||||||
use_d_vector_file: bool = False
|
use_d_vector_file: bool = False
|
||||||
|
|
|
@ -2,7 +2,11 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
|
||||||
from TTS.utils.generic_utils import find_module
|
from TTS.utils.generic_utils import find_module
|
||||||
|
|
||||||
|
|
||||||
def setup_model(config, speaker_manager: "SpeakerManager" = None):
|
def setup_model(
|
||||||
|
config,
|
||||||
|
speaker_manager: "SpeakerManager" = None,
|
||||||
|
language_manager: "LanguageManager" = None
|
||||||
|
):
|
||||||
print(" > Using model: {}".format(config.model))
|
print(" > Using model: {}".format(config.model))
|
||||||
# fetch the right model implementation.
|
# fetch the right model implementation.
|
||||||
if "base_model" in config and config["base_model"] is not None:
|
if "base_model" in config and config["base_model"] is not None:
|
||||||
|
@ -31,6 +35,9 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
|
||||||
config.model_params.num_chars = num_chars
|
config.model_params.num_chars = num_chars
|
||||||
if "model_args" in config:
|
if "model_args" in config:
|
||||||
config.model_args.num_chars = num_chars
|
config.model_args.num_chars = num_chars
|
||||||
|
if config.model.lower() in ["vits"]: # If model supports multiple languages
|
||||||
|
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
|
||||||
|
else:
|
||||||
model = MyModel(config, speaker_manager=speaker_manager)
|
model = MyModel(config, speaker_manager=speaker_manager)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -419,8 +419,7 @@ class BaseTTS(BaseModel):
|
||||||
return test_figures, test_audios
|
return test_figures, test_audios
|
||||||
|
|
||||||
def on_init_start(self, trainer):
|
def on_init_start(self, trainer):
|
||||||
"""Save the speaker.json at the beginning of the training. And update the config.json with the
|
"""Save the speakers.json and language_ids.json at the beginning of the training. Also update both paths in the config."""
|
||||||
speakers.json file path."""
|
|
||||||
if self.speaker_manager is not None:
|
if self.speaker_manager is not None:
|
||||||
output_path = os.path.join(trainer.output_path, "speakers.json")
|
output_path = os.path.join(trainer.output_path, "speakers.json")
|
||||||
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
||||||
|
@ -431,3 +430,13 @@ class BaseTTS(BaseModel):
|
||||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||||
print(f" > `speakers.json` is saved to {output_path}.")
|
print(f" > `speakers.json` is saved to {output_path}.")
|
||||||
print(" > `speakers_file` is updated in the config.json.")
|
print(" > `speakers_file` is updated in the config.json.")
|
||||||
|
|
||||||
|
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||||
|
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||||
|
self.language_manager.save_language_ids_to_file(output_path)
|
||||||
|
trainer.config.language_ids_file = output_path
|
||||||
|
if hasattr(trainer.config, "model_args"):
|
||||||
|
trainer.config.model_args.language_ids_file = output_path
|
||||||
|
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||||
|
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||||
|
print(" > `language_ids_file` is updated in the config.json.")
|
||||||
|
|
|
@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
||||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.visual import plot_alignment
|
from TTS.tts.utils.visual import plot_alignment
|
||||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
@ -158,6 +158,9 @@ class VitsArgs(Coqpit):
|
||||||
num_languages (int):
|
num_languages (int):
|
||||||
Number of languages for the language embedding layer. Defaults to 0.
|
Number of languages for the language embedding layer. Defaults to 0.
|
||||||
|
|
||||||
|
language_ids_file (str):
|
||||||
|
Path to the language mapping file for the Language Manager. Defaults to None.
|
||||||
|
|
||||||
use_speaker_encoder_as_loss (bool):
|
use_speaker_encoder_as_loss (bool):
|
||||||
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
|
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
|
||||||
|
|
||||||
|
@ -225,6 +228,7 @@ class VitsArgs(Coqpit):
|
||||||
use_language_embedding: bool = False
|
use_language_embedding: bool = False
|
||||||
embedded_language_dim: int = 4
|
embedded_language_dim: int = 4
|
||||||
num_languages: int = 0
|
num_languages: int = 0
|
||||||
|
language_ids_file: str = None
|
||||||
use_speaker_encoder_as_loss: bool = False
|
use_speaker_encoder_as_loss: bool = False
|
||||||
speaker_encoder_config_path: str = ""
|
speaker_encoder_config_path: str = ""
|
||||||
speaker_encoder_model_path: str = ""
|
speaker_encoder_model_path: str = ""
|
||||||
|
@ -265,13 +269,18 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Coqpit,
|
||||||
|
speaker_manager: SpeakerManager = None,
|
||||||
|
language_manager: LanguageManager = None,
|
||||||
|
):
|
||||||
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.END2END = True
|
self.END2END = True
|
||||||
self.speaker_manager = speaker_manager
|
self.speaker_manager = speaker_manager
|
||||||
self.audio_config = config["audio"]
|
self.language_manager = language_manager
|
||||||
if config.__class__.__name__ == "VitsConfig":
|
if config.__class__.__name__ == "VitsConfig":
|
||||||
# loading from VitsConfig
|
# loading from VitsConfig
|
||||||
if "num_chars" not in config:
|
if "num_chars" not in config:
|
||||||
|
@ -443,43 +452,20 @@ class Vits(BaseTTS):
|
||||||
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
||||||
self.embedded_speaker_dim = config.d_vector_dim
|
self.embedded_speaker_dim = config.d_vector_dim
|
||||||
|
|
||||||
if config.use_speaker_encoder_as_loss:
|
def init_multilingual(self, config: Coqpit):
|
||||||
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
|
||||||
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
|
|
||||||
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
|
|
||||||
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
|
|
||||||
for param in self.speaker_encoder.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
print(" > External Speaker Encoder Loaded !!")
|
|
||||||
|
|
||||||
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
|
|
||||||
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
|
|
||||||
else:
|
|
||||||
self.audio_transform = None
|
|
||||||
else:
|
|
||||||
self.audio_transform = None
|
|
||||||
self.speaker_encoder = None
|
|
||||||
|
|
||||||
def init_multilingual(self, config: Coqpit, data: List = None):
|
|
||||||
"""Initialize multilingual modules of a model.
|
"""Initialize multilingual modules of a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Coqpit): Model configuration.
|
config (Coqpit): Model configuration.
|
||||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
config = config.model_args
|
config = config.model_args
|
||||||
# init language manager
|
|
||||||
self.language_manager = LanguageManager(config, data=data)
|
|
||||||
|
|
||||||
# init language embedding layer
|
if config.language_ids_file is not None:
|
||||||
if config.use_language_embedding:
|
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||||
if config.num_languages > 0 and self.language_manager.num_languages == 0:
|
|
||||||
self.num_languages = config.num_languages
|
if config.use_language_embedding and self.language_manager:
|
||||||
else:
|
|
||||||
self.num_languages = self.language_manager.num_languages
|
self.num_languages = self.language_manager.num_languages
|
||||||
|
|
||||||
self.embedded_language_dim = config.embedded_language_dim
|
self.embedded_language_dim = config.embedded_language_dim
|
||||||
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
||||||
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
torch.nn.init.xavier_uniform_(self.emb_l.weight)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -14,11 +14,13 @@ class LanguageManager:
|
||||||
in a way that can be queried by language.
|
in a way that can be queried by language.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
|
language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
|
||||||
TTS models. Defaults to "".
|
TTS models. Defaults to "".
|
||||||
|
config (Coqpit, optional): Coqpit config that contains the language information in the datasets field.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
|
>>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
|
||||||
>>> language_id_mapper = manager.language_ids
|
>>> language_id_mapper = manager.language_ids
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -26,10 +28,14 @@ class LanguageManager:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
language_id_file_path: str = "",
|
language_ids_file_path: str = "",
|
||||||
|
config: Coqpit = None,
|
||||||
):
|
):
|
||||||
if language_id_file_path:
|
if language_ids_file_path:
|
||||||
self.set_language_ids_from_file(language_id_file_path)
|
self.set_language_ids_from_file(language_ids_file_path)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
self.set_language_ids_from_config(config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_json(json_file_path: str) -> Dict:
|
def _load_json(json_file_path: str) -> Dict:
|
||||||
|
@ -50,27 +56,30 @@ class LanguageManager:
|
||||||
return list(self.language_id_mapping.keys())
|
return list(self.language_id_mapping.keys())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
|
def parse_language_ids_from_config(c: Coqpit) -> Dict:
|
||||||
"""Parse language IDs from data samples retured by `load_meta_data()`.
|
"""Parse language IDs from the datasets listed in the config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items (list): Data sampled returned by `load_meta_data()`.
|
c (Coqpit): Config
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, int]: language IDs and number of languages.
|
Dict: Language ID mapping (language name -> integer ID).
|
||||||
"""
|
"""
|
||||||
languages = sorted({item[3] for item in items})
|
languages = set({})
|
||||||
language_ids = {name: i for i, name in enumerate(languages)}
|
for dataset in c.datasets:
|
||||||
num_languages = len(language_ids)
|
if "language" in dataset:
|
||||||
return language_ids, num_languages
|
languages.add(dataset["language"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
|
||||||
|
return {name: i for i, name in enumerate(sorted(list(languages)))}
|
||||||
|
|
||||||
def set_language_ids_from_data(self, items: List) -> None:
|
def set_language_ids_from_config(self, c: Coqpit) -> None:
|
||||||
"""Set language IDs from data samples.
|
"""Set language IDs from the config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items (List): Data sampled returned by `load_meta_data()`.
|
c (Coqpit): Config whose datasets specify the languages.
|
||||||
"""
|
"""
|
||||||
self.language_id_mapping, _ = self.parse_languages_from_data(items)
|
self.language_id_mapping = self.parse_language_ids_from_config(c)
|
||||||
|
|
||||||
def set_language_ids_from_file(self, file_path: str) -> None:
|
def set_language_ids_from_file(self, file_path: str) -> None:
|
||||||
"""Load language ids from a json file.
|
"""Load language ids from a json file.
|
||||||
|
@ -102,36 +111,6 @@ def _set_file_path(path):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
|
|
||||||
"""Initiate a `LanguageManager` instance by the provided config.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
c (Coqpit): Model configuration.
|
|
||||||
restore_path (str): Path to a previous training folder.
|
|
||||||
data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
|
|
||||||
out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SpeakerManager: initialized and ready to use instance.
|
|
||||||
"""
|
|
||||||
language_manager = LanguageManager()
|
|
||||||
if c.use_language_embedding:
|
|
||||||
if data is not None:
|
|
||||||
language_manager.set_language_ids_from_data(data)
|
|
||||||
if restore_path:
|
|
||||||
language_file = _set_file_path(restore_path)
|
|
||||||
# restoring language manager from a previous run.
|
|
||||||
if language_file:
|
|
||||||
language_manager.set_language_ids_from_file(language_file)
|
|
||||||
if language_manager.num_languages > 0:
|
|
||||||
print(
|
|
||||||
" > Language manager is loaded with {} languages: {}".format(
|
|
||||||
language_manager.num_languages, ", ".join(language_manager.language_names)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return language_manager
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_weighted_sampler(items: list):
|
def get_language_weighted_sampler(items: list):
|
||||||
language_names = np.array([item[3] for item in items])
|
language_names = np.array([item[3] for item in items])
|
||||||
unique_language_names = np.unique(language_names).tolist()
|
unique_language_names = np.unique(language_names).tolist()
|
||||||
|
|
Loading…
Reference in New Issue