Move multilingual logic out of the trainer

This commit is contained in:
WeberJulian 2021-11-27 22:55:21 +01:00 committed by Eren Gölge
parent 818dc4ccd8
commit 6b03943526
7 changed files with 82 additions and 99 deletions

View File

@ -5,6 +5,7 @@ from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.utils.audio import AudioProcessor
@ -60,8 +61,17 @@ def main():
else:
speaker_manager = None
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
language_manager = LanguageManager(config=config)
if hasattr(config, "model_args"):
config.model_args.num_languages = language_manager.num_languages
else:
config.num_languages = language_manager.num_languages
else:
language_manager = None
# init the model from config
model = setup_model(config, speaker_manager)
model = setup_model(config, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(

View File

@ -260,22 +260,6 @@ class Trainer:
else:
self.run_get_model(self.config, get_model)
if hasattr(self.model, "init_multilingual"):
self.model.init_multilingual(self.config, self.train_samples + self.eval_samples)
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
# save speakers json
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
self.model.language_manager.save_language_ids_to_file(
os.path.join(self.output_path, "language_ids.json")
)
if hasattr(self.config, "model_args"):
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
else:
self.config.num_languages = self.model.language_manager.num_languages
# update config file
copy_model_files(self.config, self.output_path)
# setup criterion
self.criterion = self.get_criterion(self.model)

View File

@ -85,6 +85,12 @@ class VitsConfig(BaseTTSConfig):
test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing.
language_ids_file (str):
Path to the language ids file.
use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
@ -147,6 +153,8 @@ class VitsConfig(BaseTTSConfig):
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
language_ids_file: str = None
use_language_embedding: bool = False
# use d-vectors
use_d_vector_file: bool = False

View File

@ -2,7 +2,11 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
def setup_model(config, speaker_manager: "SpeakerManager" = None):
def setup_model(
config,
speaker_manager: "SpeakerManager" = None,
language_manager: "LanguageManager" = None
):
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
@ -31,7 +35,10 @@ def setup_model(config, speaker_manager: "SpeakerManager" = None):
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
model = MyModel(config, speaker_manager=speaker_manager)
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)
return model

View File

@ -419,8 +419,7 @@ class BaseTTS(BaseModel):
return test_figures, test_audios
def on_init_start(self, trainer):
"""Save the speaker.json at the beginning of the training. And update the config.json with the
speakers.json file path."""
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
@ -431,3 +430,13 @@ class BaseTTS(BaseModel):
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")

View File

@ -16,8 +16,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@ -158,6 +158,9 @@ class VitsArgs(Coqpit):
num_languages (int):
Number of languages for the language embedding layer. Defaults to 0.
language_ids_file (str):
Path to the language mapping file for the Language Manager. Defaults to None.
use_speaker_encoder_as_loss (bool):
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
@ -225,6 +228,7 @@ class VitsArgs(Coqpit):
use_language_embedding: bool = False
embedded_language_dim: int = 4
num_languages: int = 0
language_ids_file: str = None
use_speaker_encoder_as_loss: bool = False
speaker_encoder_config_path: str = ""
speaker_encoder_model_path: str = ""
@ -265,13 +269,18 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: Coqpit,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__(config)
self.END2END = True
self.speaker_manager = speaker_manager
self.audio_config = config["audio"]
self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
@ -443,43 +452,20 @@ class Vits(BaseTTS):
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim
if config.use_speaker_encoder_as_loss:
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
for param in self.speaker_encoder.parameters():
param.requires_grad = False
print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
else:
self.audio_transform = None
else:
self.audio_transform = None
self.speaker_encoder = None
def init_multilingual(self, config: Coqpit, data: List = None):
def init_multilingual(self, config: Coqpit):
"""Initialize multilingual modules of a model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
if hasattr(config, "model_args"):
config = config.model_args
# init language manager
self.language_manager = LanguageManager(config, data=data)
# init language embedding layer
if config.use_language_embedding:
if config.num_languages > 0 and self.language_manager.num_languages == 0:
self.num_languages = config.num_languages
else:
self.num_languages = self.language_manager.num_languages
if config.language_ids_file is not None:
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if config.use_language_embedding and self.language_manager:
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = config.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
torch.nn.init.xavier_uniform_(self.emb_l.weight)

View File

@ -1,6 +1,6 @@
import json
import os
from typing import Dict, List, Tuple
from typing import Dict, List
import fsspec
import numpy as np
@ -14,11 +14,13 @@ class LanguageManager:
in a way that can be queried by language.
Args:
language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by
TTS models. Defaults to "".
config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed.
Defaults to None.
Examples:
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
>>> manager = LanguageManager(language_ids_file_path=language_ids_file_path)
>>> language_id_mapper = manager.language_ids
"""
@ -26,10 +28,14 @@ class LanguageManager:
def __init__(
self,
language_id_file_path: str = "",
language_ids_file_path: str = "",
config: Coqpit = None,
):
if language_id_file_path:
self.set_language_ids_from_file(language_id_file_path)
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)
if config:
self.set_language_ids_from_config(config)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
@ -50,27 +56,30 @@ class LanguageManager:
return list(self.language_id_mapping.keys())
@staticmethod
def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
"""Parse language IDs from data samples retured by `load_meta_data()`.
def parse_language_ids_from_config(c: Coqpit) -> Dict:
"""Set language id from config.
Args:
items (list): Data sampled returned by `load_meta_data()`.
c (Coqpit): Config
Returns:
Tuple[Dict, int]: language IDs and number of languages.
Tuple[Dict, int]: Language ID mapping and the number of languages.
"""
languages = sorted({item[3] for item in items})
language_ids = {name: i for i, name in enumerate(languages)}
num_languages = len(language_ids)
return language_ids, num_languages
languages = set({})
for dataset in c.datasets:
if "language" in dataset:
languages.add(dataset["language"])
else:
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
return {name: i for i, name in enumerate(sorted(list(languages)))}
def set_language_ids_from_data(self, items: List) -> None:
"""Set language IDs from data samples.
def set_language_ids_from_config(self, c: Coqpit) -> None:
"""Set language IDs from config samples.
Args:
items (List): Data sampled returned by `load_meta_data()`.
"""
self.language_id_mapping, _ = self.parse_languages_from_data(items)
self.language_id_mapping = self.parse_language_ids_from_config(c)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
@ -102,36 +111,6 @@ def _set_file_path(path):
return None
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
"""Initiate a `LanguageManager` instance by the provided config.
Args:
c (Coqpit): Model configuration.
restore_path (str): Path to a previous training folder.
data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
Returns:
SpeakerManager: initialized and ready to use instance.
"""
language_manager = LanguageManager()
if c.use_language_embedding:
if data is not None:
language_manager.set_language_ids_from_data(data)
if restore_path:
language_file = _set_file_path(restore_path)
# restoring language manager from a previous run.
if language_file:
language_manager.set_language_ids_from_file(language_file)
if language_manager.num_languages > 0:
print(
" > Language manager is loaded with {} languages: {}".format(
language_manager.num_languages, ", ".join(language_manager.language_names)
)
)
return language_manager
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist()