From dcb2374bc99e274c6ac3e3c541bd1a776d06423a Mon Sep 17 00:00:00 2001
From: Edresson
Date: Fri, 13 Aug 2021 21:40:34 -0300
Subject: [PATCH] Add multilingual training support to the VITS model

---
 TTS/tts/layers/glow_tts/duration_predictor.py |  11 +-
 TTS/tts/layers/vits/networks.py               |  12 +-
 .../vits/stochastic_duration_predictor.py     |  11 +-
 TTS/tts/models/base_tts.py                    |   3 +-
 TTS/tts/models/vits.py                        |  87 +++++++++--
 TTS/tts/utils/languages.py                    | 138 ++++++++++++++++++
 TTS/tts/utils/speakers.py                     |  11 +-
 7 files changed, 248 insertions(+), 25 deletions(-)
 create mode 100644 TTS/tts/utils/languages.py

diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py
index 2c0303be..f46c73a9 100644
--- a/TTS/tts/layers/glow_tts/duration_predictor.py
+++ b/TTS/tts/layers/glow_tts/duration_predictor.py
@@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
         dropout_p (float): Dropout rate used after each conv layer.
     """
 
-    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
         super().__init__()
         # class arguments
         self.in_channels = in_channels
@@ -36,7 +36,10 @@ class DurationPredictor(nn.Module):
         if cond_channels is not None and cond_channels != 0:
             self.cond = nn.Conv1d(cond_channels, in_channels, 1)
 
-    def forward(self, x, x_mask, g=None):
+        if language_emb_dim is not None and language_emb_dim != 0:
+            self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
+
+    def forward(self, x, x_mask, g=None, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, C, T]`
@@ -45,6 +48,10 @@ class DurationPredictor(nn.Module):
         """
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
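The hunks above add the same conditioning pattern already used for speaker vectors: a 1x1 convolution (`cond_lang`) projects the utterance-level language embedding to the predictor's input width, and the projection is added to the input. A minimal standalone sketch of that pattern; the tensor sizes are illustrative, not taken from the patch (4 is the default `embedded_language_dim` in `VitsArgs` below):

import torch
import torch.nn as nn

# Illustrative sizes: batch 2, 192 predictor channels, 50 time steps,
# and a 4-dim language embedding.
B, C, T, L = 2, 192, 50, 4
x = torch.randn(B, C, T)         # duration predictor input
lang_emb = torch.randn(B, L, 1)  # one language vector per utterance

# Same idea as `self.cond_lang`: project to the input width, then add.
# The [B, C, 1] projection broadcasts over the time axis.
cond_lang = nn.Conv1d(L, C, 1)
x = x + cond_lang(lang_emb)
print(x.shape)  # torch.Size([2, 192, 50])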
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index cfc8b6ac..ef426ace 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
         num_layers: int,
         kernel_size: int,
         dropout_p: float,
+        language_emb_dim: int = None,
     ):
         """Text Encoder for VITS model.
@@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
         self.hidden_channels = hidden_channels
 
         self.emb = nn.Embedding(n_vocab, hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+
+        if language_emb_dim:
+            hidden_channels += language_emb_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -72,13 +77,18 @@ class TextEncoder(nn.Module):
 
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, x, x_lengths):
+    def forward(self, x, x_lengths, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+
+        # concatenate the language embedding to every character embedding
+        if lang_emb is not None:
+            x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py
index 91e53da3..8ec7c866 100644
--- a/TTS/tts/layers/vits/stochastic_duration_predictor.py
+++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py
@@ -178,7 +178,7 @@ class StochasticDurationPredictor(nn.Module):
     """
 
     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
     ):
         super().__init__()
 
@@ -205,7 +205,10 @@ class StochasticDurationPredictor(nn.Module):
         if cond_channels != 0 and cond_channels is not None:
             self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
 
-    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+        if language_emb_dim is not None and language_emb_dim != 0:
+            self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
+
+    def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
         """
         Shapes:
             - x: :math:`[B, C, T]`
@@ -217,6 +220,10 @@ class StochasticDurationPredictor(nn.Module):
         x = self.pre(x)
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.convs(x, x_mask)
         x = self.proj(x) * x_mask
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index c55936a8..c0d2bd78 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -287,8 +287,9 @@ class BaseTTS(BaseModel):
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
             if sampler is None:
                 if getattr(config, "use_language_weighted_sampler", False):
-                    sampler = get_language_weighted_sampler(dataset.items)
                     print(" > Using Language weighted sampler")
+                    sampler = get_language_weighted_sampler(dataset.items)
+
             loader = DataLoader(
                 dataset,
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 417b6386..3a682ce5 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -17,6 +17,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
+from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -189,6 +190,9 @@ class VitsArgs(Coqpit):
     d_vector_file: str = None
     d_vector_dim: int = 0
     detach_dp_input: bool = True
+    use_language_embedding: bool = False
+    embedded_language_dim: int = 4
+    num_languages: int = 0
 
 
 class Vits(BaseTTS):
@@ -247,6 +251,7 @@ class Vits(BaseTTS):
         self.args = args
 
         self.init_multispeaker(config)
+        self.init_multilingual(config)
 
         self.length_scale = args.length_scale
         self.noise_scale = args.noise_scale
@@ -265,6 +270,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
+            language_emb_dim=self.embedded_language_dim,
         )
 
         self.posterior_encoder = PosteriorEncoder(
@@ -288,16 +294,22 @@ class Vits(BaseTTS):
 
         if args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                args.hidden_channels + self.embedded_language_dim,
                 192,
                 3,
                 args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim,
+                language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
+                args.hidden_channels + self.embedded_language_dim,
+                256,
+                3,
+                args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
+                language_emb_dim=self.embedded_language_dim,
             )
 
         self.waveform_decoder = HifiganGenerator(
@@ -356,17 +368,40 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim
 
+    def init_multilingual(self, config: Coqpit, data: List = None):
+        """Initialize multilingual modules of a model.
+
+        Args:
+            config (Coqpit): Model configuration.
+            data (List, optional): Dataset items used to infer the number of languages. Defaults to None.
+        """
+        if hasattr(config, "model_args"):
+            config = config.model_args
+
+        # init language manager
+        self.language_manager = LanguageManager()
+        if data is not None:
+            self.language_manager.set_language_ids_from_data(data)
+
+        # init language embedding layer
+        if config.use_language_embedding:
+            self.embedded_language_dim = config.embedded_language_dim
+            # fall back to `config.num_languages` when there is no data to infer it from
+            num_languages = self.language_manager.num_languages or config.num_languages
+            self.emb_l = nn.Embedding(num_languages, self.embedded_language_dim)
+            torch.nn.init.xavier_uniform_(self.emb_l.weight)
+        else:
+            self.embedded_language_dim = 0
+            self.emb_l = None
+
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g = None, None
+        sid, g, lid = None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
                 sid = sid.unsqueeze_(0)
         if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
             g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
-        return sid, g
+        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+            lid = aux_input["language_ids"]
+        return sid, g, lid
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g = self._set_cond_input(aux_input)
+        sid, g, _ = self._set_cond_input(aux_input)
@@ -378,7 +413,7 @@ class Vits(BaseTTS):
         x_lengths: torch.tensor,
         y: torch.tensor,
         y_lengths: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None},
+        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
     ) -> Dict:
         """Forward pass of the model.
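Note the width bookkeeping in the constructor hunks above: the speaker vector is summed into the duration-predictor input through `cond`/`cond_lang`, but the language embedding is also concatenated channel-wise to the character embeddings inside `TextEncoder.forward`, which is why both duration predictors now take `args.hidden_channels + self.embedded_language_dim` input channels. A minimal sketch of that concatenation, with illustrative sizes only:

import torch

B, T, H, L = 2, 30, 192, 4       # batch, chars, hidden channels, language dim
char_emb = torch.randn(B, T, H)  # output of `self.emb(x)`, [b, t, h]
lang_emb = torch.randn(B, L, 1)  # `self.emb_l(lid).unsqueeze(-1)`

# Broadcast the utterance-level language vector over time and append it
# to every character embedding, as in `TextEncoder.forward`.
x = torch.cat((char_emb, lang_emb.transpose(2, 1).expand(B, T, -1)), dim=-1)
print(x.shape)  # torch.Size([2, 30, 196]) == hidden_channels + language_emb_dim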
@@ -401,13 +436,20 @@ class Vits(BaseTTS):
             - speaker_ids: :math:`[B]`
         """
         outputs = {}
-        sid, g = self._set_cond_input(aux_input)
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+        sid, g, lid = self._set_cond_input(aux_input)
 
         # speaker embedding
         if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
 
+        # language embedding
+        lang_emb = None
+        if self.args.use_language_embedding:
+            lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -433,6 +474,7 @@ class Vits(BaseTTS):
                 x_mask,
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
         else:
@@ -441,6 +483,7 @@ class Vits(BaseTTS):
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
         outputs["loss_duration"] = loss_duration
@@ -467,25 +510,31 @@ class Vits(BaseTTS):
         )
         return outputs
 
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
         Shapes:
             - x: :math:`[B, T_seq]`
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
         """
-        sid, g = self._set_cond_input(aux_input)
+        sid, g, lid = self._set_cond_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
 
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
-
-        if self.num_speakers > 0 and sid is not None:
+        # speaker embedding
+        if self.num_speakers > 0 and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)
 
+        # language embedding
+        lang_emb = None
+        if self.args.use_language_embedding:
+            lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
+            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
         else:
-            logw = self.duration_predictor(x, x_mask, g=g)
+            logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
 
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
@@ -537,6 +585,7 @@ class Vits(BaseTTS):
             linear_input = batch["linear_input"]
             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
+            language_ids = batch["language_ids"]
             waveform = batch["waveform"]
 
             # generator pass
@@ -545,7 +594,7 @@ class Vits(BaseTTS):
                 text_lengths,
                 linear_input.transpose(1, 2),
                 mel_lengths,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
             )
 
             # cache tensors for the discriminator
@@ -581,6 +630,14 @@ class Vits(BaseTTS):
                 loss_duration=outputs["loss_duration"],
             )
 
+            # add the duration loss; `forward` stores it under "loss_duration" for both predictor types
+            if self.args.use_sdp:
+                loss_dict["nll_duration"] = outputs["loss_duration"]
+                loss_dict["loss"] += outputs["loss_duration"]
+            else:
+                loss_dict["loss_duration"] = outputs["loss_duration"]
+                loss_dict["loss"] += outputs["loss_duration"]
+
         elif optimizer_idx == 1:
             # discriminator pass
             outputs = {}
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
new file mode 100644
index 00000000..b87b9936
--- /dev/null
+++ b/TTS/tts/utils/languages.py
@@ -0,0 +1,138 @@
+import json
+import os
+from typing import Dict, List, Tuple
+
+import fsspec
+import numpy as np
+import torch
+from coqpit import Coqpit
+from torch.utils.data.sampler import WeightedRandomSampler
+
+
+class LanguageManager:
+    """Manage the languages of multi-lingual 🐸TTS models. Load a data file and parse the information
+    in a way that can be queried by language.
+
+    Args:
+        language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
+        TTS models. Defaults to "".
+
+    Examples:
+        >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
+        >>> language_id_mapping = manager.language_id_mapping
+    """
+
+    def __init__(
+        self,
+        language_id_file_path: str = "",
+    ):
+        self.language_id_mapping = {}
+        if language_id_file_path:
+            self.set_language_ids_from_file(language_id_file_path)
+
+    @staticmethod
+    def _load_json(json_file_path: str) -> Dict:
+        with fsspec.open(json_file_path, "r") as f:
+            return json.load(f)
+
+    @staticmethod
+    def _save_json(json_file_path: str, data: dict) -> None:
+        with fsspec.open(json_file_path, "w") as f:
+            json.dump(data, f, indent=4)
+
+    @property
+    def num_languages(self) -> int:
+        return len(self.language_id_mapping)
+
+    @property
+    def language_names(self) -> List:
+        return list(self.language_id_mapping.keys())
+
+    @staticmethod
+    def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
+        """Parse language IDs from the data samples returned by `load_meta_data()`.
+
+        Args:
+            items (list): Data samples returned by `load_meta_data()`.
+
+        Returns:
+            Tuple[Dict, int]: Language IDs and the number of languages.
+        """
+        languages = sorted({item[3] for item in items})
+        language_ids = {name: i for i, name in enumerate(languages)}
+        num_languages = len(language_ids)
+        return language_ids, num_languages
+
+    def set_language_ids_from_data(self, items: List) -> None:
+        """Set language IDs from data samples.
+
+        Args:
+            items (List): Data samples returned by `load_meta_data()`.
+        """
+        self.language_id_mapping, _ = self.parse_languages_from_data(items)
+
+    def set_language_ids_from_file(self, file_path: str) -> None:
+        """Load language IDs from a JSON file.
+
+        Args:
+            file_path (str): Path to the target JSON file.
+        """
+        self.language_id_mapping = self._load_json(file_path)
+
+    def save_language_ids_to_file(self, file_path: str) -> None:
+        """Save language IDs to a JSON file.
+
+        Args:
+            file_path (str): Path to the output file.
+        """
+        self._save_json(file_path, self.language_id_mapping)
+
+
+def _set_file_path(path):
+    """Find `language_ids.json` under the given path or one level above it.
+    Intended to work around the different paths returned for restored and continued trainings."""
+    path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
+    path_continue = os.path.join(path, "language_ids.json")
+    fs = fsspec.get_mapper(path).fs
+    if fs.exists(path_restore):
+        return path_restore
+    if fs.exists(path_continue):
+        return path_continue
+    return None
+
+
+def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> LanguageManager:
+    """Initiate a `LanguageManager` instance from the provided config.
+
+    Args:
+        c (Coqpit): Model configuration.
+        data (List): Data samples returned by `load_meta_data()`. Defaults to None.
+        restore_path (str): Path to a previous training folder. Defaults to None.
+        out_path (str, optional): Save the generated language IDs to an output path. Defaults to None.
+
+    Returns:
+        LanguageManager: initialized and ready-to-use instance.
+    """
+    language_manager = LanguageManager()
+    if c.use_language_embedding:
+        if data is not None:
+            language_manager.set_language_ids_from_data(data)
+        if restore_path:
+            language_file = _set_file_path(restore_path)
+            # restoring language manager from a previous run.
+            if language_file:
+                language_manager.set_language_ids_from_file(language_file)
+        if language_manager.num_languages > 0:
+            print(
+                " > Language manager is loaded with {} languages: {}".format(
+                    language_manager.num_languages, ", ".join(language_manager.language_names)
+                )
+            )
+        # save the language IDs file if an output path is given
+        if out_path:
+            out_file_path = os.path.join(out_path, "language_ids.json")
+            language_manager.save_language_ids_to_file(out_file_path)
+    return language_manager
+
+
+def get_language_weighted_sampler(items: list):
+    language_names = np.array([item[3] for item in items])
+    unique_language_names = np.unique(language_names).tolist()
+    language_ids = [unique_language_names.index(lang) for lang in language_names]
+    language_count = np.array([len(np.where(language_names == lang)[0]) for lang in unique_language_names])
+    weight_language = 1.0 / language_count
+    dataset_samples_weight = torch.from_numpy(np.array([weight_language[lid] for lid in language_ids])).double()
+    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
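The new utilities only assume that the language name sits at index 3 of each data sample, as in the tuples returned by `load_meta_data()`. A hypothetical usage sketch (texts, file paths, and speaker names are made up):

from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler

items = [
    ("ola mundo", "wavs/pt_0.wav", "speaker0", "pt-br"),
    ("hello world", "wavs/en_0.wav", "speaker1", "en"),
    ("hello again", "wavs/en_1.wav", "speaker1", "en"),
]

manager = LanguageManager()
manager.set_language_ids_from_data(items)
print(manager.language_id_mapping)  # {'en': 0, 'pt-br': 1}
print(manager.num_languages)        # 2

# Draws each language with equal probability by oversampling "pt-br";
# passed to a DataLoader in place of shuffle=True (see base_tts.py above).
sampler = get_language_weighted_sampler(items)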
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 5d883fd0..b7dd5251 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -379,11 +379,14 @@
     elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
         # new speaker manager with speaker IDs file.
         speaker_manager.set_speaker_ids_from_file(c.speakers_file)
-    print(
-        " > Speaker manager is loaded with {} speakers: {}".format(
-            speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
+
+    if speaker_manager.num_speakers > 0:
+        print(
+            " > Speaker manager is loaded with {} speakers: {}".format(
+                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
+            )
         )
-    )
+
     # save file if path is defined
     if out_path:
         out_file_path = os.path.join(out_path, "speakers.json")
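End to end, multilingual synthesis only requires routing a language ID through `aux_input`. A sketch, assuming `model` is a `Vits` instance trained with `use_language_embedding=True`; the token IDs and sizes here are placeholders, not real phoneme IDs:

import torch

token_ids = torch.randint(0, 100, (1, 42))  # [B, T_seq], dummy input
aux_input = {
    "d_vectors": None,
    "speaker_ids": torch.tensor([0]),
    "language_ids": torch.tensor([1]),  # row index into `emb_l`
}
outputs = model.inference(token_ids, aux_input=aux_input)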