Move multilingual logic out of the trainer

This commit is contained in:
WeberJulian 2021-11-27 22:55:21 +01:00 committed by Eren Gölge
parent b909a3b63e
commit 352b4be104
7 changed files with 82 additions and 99 deletions

View File

@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -60,8 +61,17 @@ def main():
else: else:
speaker_manager = None speaker_manager = None
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
language_manager = LanguageManager(config=config)
if hasattr(config, "model_args"):
config.model_args.num_languages = language_manager.num_languages
else:
config.num_languages = language_manager.num_languages
else:
language_manager = None
# init the model from config # init the model from config
model = setup_model(config, speaker_manager) model = setup_model(config, speaker_manager, language_manager)
# init the trainer and 🚀 # init the trainer and 🚀
trainer = Trainer( trainer = Trainer(

View File

@ -260,22 +260,6 @@ class Trainer:
else: else:
self.run_get_model(self.config, get_model) self.run_get_model(self.config, get_model)
if hasattr(self.model, "init_multilingual"):
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
# save speakers json
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
self.model.language_manager.save_language_ids_to_file(
os.path.join(self.output_path, "language_ids.json")
)
if hasattr(self.config, "model_args"):
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
else:
self.config.num_languages = self.model.language_manager.num_languages
# update config file
copy_model_files(self.config, self.output_path)
# setup criterion # setup criterion
self.criterion = self.get_criterion(self.model) self.criterion = self.get_criterion(self.model)

View File

@ -85,6 +85,12 @@ class VitsConfig(BaseTTSConfig):
test_sentences (List[List]): test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing. List of sentences with speaker and language information to be used for testing.
language_ids_file (str):
Path to the language ids file.
use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.
Note: Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@ -147,6 +153,8 @@ class VitsConfig(BaseTTSConfig):
use_speaker_embedding: bool = False use_speaker_embedding: bool = False
speakers_file: str = None speakers_file: str = None
speaker_embedding_channels: int = 256 speaker_embedding_channels: int = 256
language_ids_file: str = None
use_language_embedding: bool = False
# use d-vectors # use d-vectors
use_d_vector_file: bool = False use_d_vector_file: bool = False

View File

@ -2,7 +2,11 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module from TTS.utils.generic_utils import find_module
def setup_model(config, speaker_manager: "SpeakerManager" = None): def setup_model(
config,
speaker_manager: "SpeakerManager" = None,
language_manager: "LanguageManager" = None
):
print(" > Using model: {}".format(config.model)) print(" > Using model: {}".format(config.model))
# fetch the right model implementation. # fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None: if "base_model" in config and config["base_model"] is not None:
@ -31,6 +35,9 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
config.model_params.num_chars = num_chars config.model_params.num_chars = num_chars
if "model_args" in config: if "model_args" in config:
config.model_args.num_chars = num_chars config.model_args.num_chars = num_chars
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager) model = MyModel(config, speaker_manager=speaker_manager)
return model return model

View File

@ -419,8 +419,7 @@ class BaseTTS(BaseModel):
return test_figures, test_audios return test_figures, test_audios
def on_init_start(self, trainer): def on_init_start(self, trainer):
"""Save the speaker.json at the beginning of the training. And update the config.json with the """Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
speakers.json file path."""
if self.speaker_manager is not None: if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json") output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path) self.speaker_manager.save_speaker_ids_to_file(output_path)
@ -431,3 +430,13 @@ class BaseTTS(BaseModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.") print(f" > `speakers.json` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.") print(" > `speakers_file` is updated in the config.json.")
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")

View File

@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@ -158,6 +158,9 @@ class VitsArgs(Coqpit):
num_languages (int): num_languages (int):
Number of languages for the language embedding layer. Defaults to 0. Number of languages for the language embedding layer. Defaults to 0.
language_ids_file (str):
Path to the language mapping file for the Language Manager. Defaults to None.
use_speaker_encoder_as_loss (bool): use_speaker_encoder_as_loss (bool):
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
@ -225,6 +228,7 @@ class VitsArgs(Coqpit):
use_language_embedding: bool = False use_language_embedding: bool = False
embedded_language_dim: int = 4 embedded_language_dim: int = 4
num_languages: int = 0 num_languages: int = 0
language_ids_file: str = None
use_speaker_encoder_as_loss: bool = False use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = "" speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = "" speaker_encoder_model_path: str = ""
@ -265,13 +269,18 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): def __init__(
self,
config: Coqpit,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__(config) super().__init__(config)
self.END2END = True self.END2END = True
self.speaker_manager = speaker_manager self.speaker_manager = speaker_manager
self.audio_config = config["audio"] self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig": if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig # loading from VitsConfig
if "num_chars" not in config: if "num_chars" not in config:
@ -443,43 +452,20 @@ class Vits(BaseTTS):
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim self.embedded_speaker_dim = config.d_vector_dim
if config.use_speaker_encoder_as_loss: def init_multilingual(self, config: Coqpit):
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
for param in self.speaker_encoder.parameters():
param.requires_grad = False
print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
else:
self.audio_transform = None
else:
self.audio_transform = None
self.speaker_encoder = None
def init_multilingual(self, config: Coqpit, data: List = None):
"""Initialize multilingual modules of a model. """Initialize multilingual modules of a model.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
""" """
if hasattr(config, "model_args"): if hasattr(config, "model_args"):
config = config.model_args config = config.model_args
# init language manager
self.language_manager = LanguageManager(config, data=data)
# init language embedding layer if config.language_ids_file is not None:
if config.use_language_embedding: self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if config.num_languages > 0 and self.language_manager.num_languages == 0:
self.num_languages = config.num_languages if config.use_language_embedding and self.language_manager:
else:
self.num_languages = self.language_manager.num_languages self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = config.embedded_language_dim self.embedded_language_dim = config.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
torch.nn.init.xavier_uniform_(self.emb_l.weight) torch.nn.init.xavier_uniform_(self.emb_l.weight)

View File

@ -1,6 +1,6 @@
import json import json
import os import os
from typing import Dict, List, Tuple from typing import Dict, List
import fsspec import fsspec
import numpy as np import numpy as np
@ -14,11 +14,13 @@ class LanguageManager:
in a way that can be queried by language. in a way that can be queried by language.
Args: Args:
language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
TTS models. Defaults to "". TTS models. Defaults to "".
config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed.
Defaults to None.
Examples: Examples:
>>> manager = LanguageManager(language_id_file_path=language_id_file_path) >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
>>> language_id_mapper = manager.language_ids >>> language_id_mapper = manager.language_ids
""" """
@ -26,10 +28,14 @@ class LanguageManager:
def __init__( def __init__(
self, self,
language_id_file_path: str = "", language_ids_file_path: str = "",
config: Coqpit = None,
): ):
if language_id_file_path: if language_ids_file_path:
self.set_language_ids_from_file(language_id_file_path) self.set_language_ids_from_file(language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod @staticmethod
def _load_json(json_file_path: str) -> Dict: def _load_json(json_file_path: str) -> Dict:
@ -50,27 +56,30 @@ class LanguageManager:
return list(self.language_id_mapping.keys()) return list(self.language_id_mapping.keys())
@staticmethod @staticmethod
def parse_languages_from_data(items: list) -> Tuple[Dict, int]: def parse_language_ids_from_config(c: Coqpit) -> Dict:
"""Parse language IDs from data samples retured by `load_meta_data()`. """Set language id from config.
Args: Args:
items (list): Data sampled returned by `load_meta_data()`. c (Coqpit): Config
Returns: Returns:
Tuple[Dict, int]: language IDs and number of languages. Tuple[Dict, int]: Language ID mapping and the number of languages.
""" """
languages = sorted({item[3] for item in items}) languages = set({})
language_ids = {name: i for i, name in enumerate(languages)} for dataset in c.datasets:
num_languages = len(language_ids) if "language" in dataset:
return language_ids, num_languages languages.add(dataset["language"])
else:
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
return {name: i for i, name in enumerate(sorted(list(languages)))}
def set_language_ids_from_data(self, items: List) -> None: def set_language_ids_from_config(self, c: Coqpit) -> None:
"""Set language IDs from data samples. """Set language IDs from config samples.
Args: Args:
items (List): Data sampled returned by `load_meta_data()`. items (List): Data sampled returned by `load_meta_data()`.
""" """
self.language_id_mapping, _ = self.parse_languages_from_data(items) self.language_id_mapping = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None: def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file. """Load language ids from a json file.
@ -102,36 +111,6 @@ def _set_file_path(path):
return None return None
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
"""Initiate a `LanguageManager` instance by the provided config.
Args:
c (Coqpit): Model configuration.
restore_path (str): Path to a previous training folder.
data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
Returns:
SpeakerManager: initialized and ready to use instance.
"""
language_manager = LanguageManager()
if c.use_language_embedding:
if data is not None:
language_manager.set_language_ids_from_data(data)
if restore_path:
language_file = _set_file_path(restore_path)
# restoring language manager from a previous run.
if language_file:
language_manager.set_language_ids_from_file(language_file)
if language_manager.num_languages > 0:
print(
" > Language manager is loaded with {} languages: {}".format(
language_manager.num_languages, ", ".join(language_manager.language_names)
)
)
return language_manager
def get_language_weighted_sampler(items: list): def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items]) language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist() unique_language_names = np.unique(language_names).tolist()