Add multilingual training support to the VITS model

Edresson 2021-08-13 21:40:34 -03:00 committed by Eren Gölge
parent 829ee55b04
commit d0e3647db6
7 changed files with 248 additions and 25 deletions

View File

@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
dropout_p (float): Dropout rate used after each conv layer.
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
super().__init__()
# class arguments
self.in_channels = in_channels
@ -36,7 +36,10 @@ class DurationPredictor(nn.Module):
if cond_channels is not None and cond_channels != 0:
self.cond = nn.Conv1d(cond_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
if language_emb_dim != 0 and language_emb_dim is not None:
self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
def forward(self, x, x_mask, g=None, lang_emb=None):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -45,6 +48,10 @@ class DurationPredictor(nn.Module):
"""
if g is not None:
x = x + self.cond(g)
if lang_emb is not None:
x = x + self.cond_lang(lang_emb)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)

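For context, the added language conditioning mirrors the existing speaker path: cond_lang is a 1x1 convolution that projects the language embedding to in_channels, and its output is broadcast-added over the time axis. A minimal usage sketch — the import path and all dimensions are illustrative assumptions, not part of this commit:

import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor  # assumed location of this module

B, T, E = 2, 50, 4                  # batch, text length, language_emb_dim (illustrative)
C = 192 + E                         # in_channels: the text-encoder width grows by E after the concat in the next file

dp = DurationPredictor(C, 256, 3, 0.5, language_emb_dim=E)
x = torch.randn(B, C, T)            # text-encoder output, [B, C, T]
x_mask = torch.ones(B, 1, T)
lang_emb = torch.randn(B, E, 1)     # in VITS this is emb_l(language_ids).unsqueeze(-1)

log_durations = dp(x, x_mask, lang_emb=lang_emb)  # [B, 1, T]; cond_lang(lang_emb) broadcasts over T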
View File

@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
num_layers: int,
kernel_size: int,
dropout_p: float,
language_emb_dim: int = None,
):
"""Text Encoder for VITS model.
@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
if language_emb_dim:
hidden_channels += language_emb_dim
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
@ -72,13 +77,18 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
def forward(self, x, x_lengths, lang_emb=None):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
# concat the language embedding to the char embeddings
if lang_emb is not None:
x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

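The concatenation above tiles one language embedding across every character position, so the transformer now consumes hidden_channels + language_emb_dim channels. The same shape arithmetic in isolation (dimensions are illustrative):

import torch

B, T, H, E = 2, 17, 192, 4           # batch, chars, hidden_channels, language_emb_dim
x = torch.randn(B, T, H)             # character embeddings after self.emb, [B, T, H]
lang_emb = torch.randn(B, E, 1)      # one embedding per utterance, [B, E, 1]

tiled = lang_emb.transpose(2, 1).expand(B, T, -1)  # [B, T, E], repeated for every character
x = torch.cat((x, tiled), dim=-1)                  # [B, T, H + E]
assert x.shape == (B, T, H + E)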
View File

@ -178,7 +178,7 @@ class StochasticDurationPredictor(nn.Module):
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
):
super().__init__()
@ -205,7 +205,10 @@ class StochasticDurationPredictor(nn.Module):
if cond_channels != 0 and cond_channels is not None:
self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
if language_emb_dim != 0 and language_emb_dim is not None:
self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -217,6 +220,10 @@ class StochasticDurationPredictor(nn.Module):
x = self.pre(x)
if g is not None:
x = x + self.cond(g)
if lang_emb is not None:
x = x + self.cond_lang(lang_emb)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask

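At inference time VITS runs this module with reverse=True, so the flow maps noise (scaled by noise_scale) to log-durations instead of computing a likelihood, and the language conditioning enters exactly as in training. A hedged sketch under the constructor and forward signatures shown in this diff; dimensions are illustrative:

import torch
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor

B, T, E = 2, 50, 4
C = 192 + E                          # matches the hidden_channels + embedded_language_dim wiring in vits.py

sdp = StochasticDurationPredictor(C, 192, 3, 0.5, num_flows=4, language_emb_dim=E)
x = torch.randn(B, C, T)             # text-encoder output
x_mask = torch.ones(B, 1, T)
lang_emb = torch.randn(B, E, 1)

# reverse pass: sample log-durations from noise, conditioned on text and language
logw = sdp(x, x_mask, lang_emb=lang_emb, reverse=True, noise_scale=0.8)
w = torch.exp(logw) * x_mask         # durations in frames, [B, 1, T]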
View File

@ -287,8 +287,9 @@ class BaseTTS(BaseModel):
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
sampler = get_language_weighted_sampler(dataset.items)
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
loader = DataLoader(
dataset,

View File

@ -17,6 +17,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import get_language_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@ -189,6 +190,9 @@ class VitsArgs(Coqpit):
d_vector_file: str = None
d_vector_dim: int = 0
detach_dp_input: bool = True
use_language_embedding: bool = False
embedded_language_dim: int = 4
num_languages: int = 0
class Vits(BaseTTS):
@ -247,6 +251,7 @@ class Vits(BaseTTS):
self.args = args
self.init_multispeaker(config)
self.init_multilingual(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
@ -265,6 +270,7 @@ class Vits(BaseTTS):
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
language_emb_dim=self.embedded_language_dim,
)
self.posterior_encoder = PosteriorEncoder(
@ -288,16 +294,22 @@ class Vits(BaseTTS):
if args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels,
args.hidden_channels + self.embedded_language_dim,
192,
3,
args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim,
language_emb_dim=self.embedded_language_dim,
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
args.hidden_channels + self.embedded_language_dim,
256,
3,
args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim,
language_emb_dim=self.embedded_language_dim,
)
self.waveform_decoder = HifiganGenerator(
@ -356,17 +368,40 @@ class Vits(BaseTTS):
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim
def init_multilingual(self, config: Coqpit, data: List = None):
"""Initialize multilingual modules of a model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items used to infer the number of languages. Defaults to None.
"""
if hasattr(config, "model_args"):
config = config.model_args
# init language manager
self.language_manager = get_language_manager(config, data=data)
# init language embedding layer
if config.use_language_embedding:
self.embedded_language_dim = config.embedded_language_dim
self.emb_l = nn.Embedding(self.language_manager.num_languages, self.embedded_language_dim)
torch.nn.init.xavier_uniform_(self.emb_l.weight)
else:
self.embedded_language_dim = 0
self.emb_l = None
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g = None, None
sid, g, lid = None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
return sid, g
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
lid = aux_input["language_ids"]
return sid, g, lid
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
@ -378,7 +413,7 @@ class Vits(BaseTTS):
x_lengths: torch.tensor,
y: torch.tensor,
y_lengths: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None},
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
"""Forward pass of the model.
@ -401,13 +436,19 @@ class Vits(BaseTTS):
- speaker_ids: :math:`[B]`
- language_ids: :math:`[B]`
"""
outputs = {}
sid, g = self._set_cond_input(aux_input)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
sid, g, lid = self._set_cond_input(aux_input)
# speaker embedding
if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# language embedding (None when multilingual training is disabled)
lang_emb = None
if self.args.use_language_embedding and lid is not None:
    lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@ -433,6 +474,7 @@ class Vits(BaseTTS):
x_mask,
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = loss_duration / torch.sum(x_mask)
else:
@ -441,6 +483,7 @@ class Vits(BaseTTS):
x.detach() if self.args.detach_dp_input else x,
x_mask,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
)
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration
@ -467,25 +510,30 @@ class Vits(BaseTTS):
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
"""
Shapes:
- x: :math:`[B, T_seq]`
- d_vectors: :math:`[B, C, 1]`
- speaker_ids: :math:`[B]`
- language_ids: :math:`[B]`
"""
sid, g = self._set_cond_input(aux_input)
sid, g, lid = self._set_cond_input(aux_input)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
if self.num_speakers > 0 and sid is not None:
# speaker embedding
if self.num_speakers > 0 and sid:
g = self.emb_g(sid).unsqueeze(-1)
# language embedding (None when multilingual training is disabled)
lang_emb = None
if self.args.use_language_embedding and lid is not None:
    lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
else:
logw = self.duration_predictor(x, x_mask, g=g)
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
@ -537,6 +585,7 @@ class Vits(BaseTTS):
linear_input = batch["linear_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
language_ids = batch["language_ids"]
waveform = batch["waveform"]
# generator pass
@ -545,7 +594,7 @@ class Vits(BaseTTS):
text_lengths,
linear_input.transpose(1, 2),
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
)
# cache tensors for the discriminator
@ -581,6 +630,14 @@ class Vits(BaseTTS):
loss_duration=outputs["loss_duration"],
)
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["loss_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}

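End to end, the new aux_input key is all a caller has to supply: language_ids indexes emb_l, and the resulting embedding is threaded through the text encoder and duration predictor. A hedged sketch of multilingual inference — `model` stands for an already-built multilingual Vits instance, the output key follows the usual Coqui convention, and all ids and shapes are illustrative:

import torch

x = torch.randint(0, 100, (1, 42))       # token ids, [B, T_seq]
aux_input = {
    "d_vectors": None,
    "speaker_ids": torch.tensor([0]),    # [B]
    "language_ids": torch.tensor([1]),   # [B], row index into emb_l
}
outputs = model.inference(x, aux_input=aux_input)
waveform = outputs["model_outputs"]      # assumed output key, per the usual Coqui convention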
TTS/tts/utils/languages.py Normal file
View File

@ -0,0 +1,138 @@
import os
import json
import torch
import fsspec
import numpy as np
from typing import Dict, Tuple, List
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
class LanguageManager:
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
Args:
language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
TTS models. Defaults to "".
Examples:
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
>>> language_id_mapper = manager.language_ids
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_id_file_path: str = "",
):
if language_id_file_path:
self.set_language_ids_from_file(language_id_file_path)
@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_languages(self) -> int:
return len(list(self.language_id_mapping.keys()))
@property
def language_names(self) -> List:
return list(self.language_id_mapping.keys())
@staticmethod
def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
"""Parse language IDs from data samples retured by `load_meta_data()`.
Args:
items (list): Data sampled returned by `load_meta_data()`.
Returns:
Tuple[Dict, int]: language IDs and number of languages.
"""
languages = sorted({item[3] for item in items})
language_ids = {name: i for i, name in enumerate(languages)}
num_languages = len(language_ids)
return language_ids, num_languages
def set_language_ids_from_data(self, items: List) -> None:
"""Set language IDs from data samples.
Args:
items (List): Data samples returned by `load_meta_data()`.
"""
self.language_id_mapping, _ = self.parse_languages_from_data(items)
def set_language_ids_from_file(self, file_path: str) -> None:
"""Load language ids from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.language_id_mapping = self._load_json(file_path)
def save_language_ids_to_file(self, file_path: str) -> None:
"""Save language IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.language_id_mapping)
def _set_file_path(path):
"""Find the language_ids.json under the given path or the above it.
Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
path_continue = os.path.join(path, "language_ids.json")
fs = fsspec.get_mapper(path).fs
if fs.exists(path_restore):
return path_restore
if fs.exists(path_continue):
return path_continue
return None
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> LanguageManager:
"""Initiate a `LanguageManager` instance by the provided config.
Args:
c (Coqpit): Model configuration.
restore_path (str): Path to a previous training folder.
data (List): Data sampled returned by `load_meta_data()`. Defaults to None.
out_path (str, optional): Save the generated language IDs to a output path. Defaults to None.
Returns:
SpeakerManager: initialized and ready to use instance.
"""
language_manager = LanguageManager()
if c.use_language_embedding:
if data is not None:
language_manager.set_language_ids_from_data(data)
if restore_path:
language_file = _set_file_path(restore_path)
# restoring language manager from a previous run.
if language_file:
language_manager.set_language_ids_from_file(language_file)
if language_manager.num_languages > 0:
print(
" > Language manager is loaded with {} languages: {}".format(
language_manager.num_languages, ", ".join(language_manager.language_names)
)
)
return language_manager
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1. / language_count
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
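The sampler weights each item by the inverse frequency of its language, so minority languages are drawn about as often as majority ones. A quick self-contained check — the item layout follows the `item[3]` indexing above, and the contents are made up:

import torch
from TTS.tts.utils.languages import get_language_weighted_sampler

# items as in load_meta_data() output, with the language name at index 3
items = [
    ["hello", "a.wav", "spk1", "en"],
    ["world", "b.wav", "spk1", "en"],
    ["hallo", "c.wav", "spk2", "de"],
]
sampler = get_language_weighted_sampler(items)
# "de" gets weight 1/1 and each "en" item 1/2, so both languages are sampled roughly equally
print(list(sampler))  # e.g. [2, 0, 2] -- three indices drawn with replacement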

View File

@ -379,11 +379,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
# new speaker manager with speaker IDs file.
speaker_manager.set_speaker_ids_from_file(c.speakers_file)
print(
    " > Speaker manager is loaded with {} speakers: {}".format(
        speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
    )
)
if speaker_manager.num_speakers > 0:
    print(
        " > Speaker manager is loaded with {} speakers: {}".format(
            speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
        )
    )
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")