mirror of https://github.com/coqui-ai/TTS.git

Add multilingual training support to the VITS model

This commit is contained in:
parent 829ee55b04
commit d0e3647db6
@@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
         dropout_p (float): Dropout rate used after each conv layer.
     """

-    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
         super().__init__()
         # class arguments
         self.in_channels = in_channels

@@ -36,7 +36,10 @@ class DurationPredictor(nn.Module):
         if cond_channels is not None and cond_channels != 0:
             self.cond = nn.Conv1d(cond_channels, in_channels, 1)

-    def forward(self, x, x_mask, g=None):
+        if language_emb_dim != 0 and language_emb_dim is not None:
+            self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)
+
+    def forward(self, x, x_mask, g=None, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, C, T]`

@@ -45,6 +48,10 @@ class DurationPredictor(nn.Module):
         """
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
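A minimal standalone sketch (not part of the commit) of how this additive conditioning behaves; tensor sizes are illustrative. The same 1x1-conv pattern also conditions StochasticDurationPredictor below:

    import torch
    import torch.nn as nn

    B, C, T = 2, 192, 50                            # batch, channels, time steps
    language_emb_dim = 4

    x = torch.randn(B, C, T)                        # hidden text features, [B, C, T]
    lang_emb = torch.randn(B, language_emb_dim, 1)  # one embedding per utterance, [B, D, 1]

    # the 1x1 conv projects the language embedding into the feature space;
    # broadcasting then adds it at every time step
    cond_lang = nn.Conv1d(language_emb_dim, C, 1)
    y = x + cond_lang(lang_emb)                     # still [B, C, T]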
@@ -37,6 +37,7 @@ class TextEncoder(nn.Module):
         num_layers: int,
         kernel_size: int,
         dropout_p: float,
+        language_emb_dim: int = None,
     ):
         """Text Encoder for VITS model.


@@ -55,8 +56,12 @@ class TextEncoder(nn.Module):
         self.hidden_channels = hidden_channels

         self.emb = nn.Embedding(n_vocab, hidden_channels)
+
         nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

+        if language_emb_dim:
+            hidden_channels += language_emb_dim
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,

@@ -72,13 +77,18 @@ class TextEncoder(nn.Module):

         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(self, x, x_lengths):
+    def forward(self, x, x_lengths, lang_emb=None):
         """
         Shapes:
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+
+        # concat the lang emb in embedding chars
+        if lang_emb is not None:
+            x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
+
         x = torch.transpose(x, 1, -1)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

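A standalone sketch (not from the commit) of the concatenation above; it shows why `hidden_channels` is widened by `language_emb_dim` before the transformer is built:

    import torch

    B, T, H, D = 2, 7, 192, 4
    x = torch.randn(B, T, H)         # character embeddings, [B, T, H]
    lang_emb = torch.randn(B, D, 1)  # language embedding, [B, D, 1]

    # [B, D, 1] -> [B, 1, D], broadcast over T, then concat on the channel axis
    x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
    print(x.shape)  # torch.Size([2, 7, 196]), i.e. H + D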
@@ -178,7 +178,7 @@ class StochasticDurationPredictor(nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
     ):
         super().__init__()


@@ -205,7 +205,10 @@ class StochasticDurationPredictor(nn.Module):
         if cond_channels != 0 and cond_channels is not None:
             self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)

-    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+        if language_emb_dim != 0 and language_emb_dim is not None:
+            self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1)
+
+    def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0):
         """
         Shapes:
             - x: :math:`[B, C, T]`

@@ -217,6 +220,10 @@ class StochasticDurationPredictor(nn.Module):
         x = self.pre(x)
         if g is not None:
             x = x + self.cond(g)
+
+        if lang_emb is not None:
+            x = x + self.cond_lang(lang_emb)
+
         x = self.convs(x, x_mask)
         x = self.proj(x) * x_mask
@@ -287,8 +287,9 @@ class BaseTTS(BaseModel):
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
             if sampler is None:
                 if getattr(config, "use_language_weighted_sampler", False):
-                    sampler = get_language_weighted_sampler(dataset.items)
+                    print(" > Using Language weighted sampler")
+                    sampler = get_language_weighted_sampler(dataset.items)

             loader = DataLoader(
                 dataset,
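For reference, on a single GPU the branch above resolves to roughly the following; `batch_size` is only illustrative, and `get_language_weighted_sampler` is the helper added later in this commit, which assumes each row of `dataset.items` carries the language name:

    from torch.utils.data import DataLoader

    sampler = get_language_weighted_sampler(dataset.items)  # oversamples rare languages
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)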
@@ -17,6 +17,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -189,6 +190,9 @@ class VitsArgs(Coqpit):
     d_vector_file: str = None
     d_vector_dim: int = 0
     detach_dp_input: bool = True
+    use_language_embedding: bool = False
+    embedded_language_dim: int = 4
+    num_languages: int = 0


 class Vits(BaseTTS):
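A hypothetical config snippet enabling the new fields; the values shown are the defaults from the dataclass above:

    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        use_language_embedding=True,  # enables the learned nn.Embedding added below
        embedded_language_dim=4,      # size of each language vector
    )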
@@ -247,6 +251,7 @@ class Vits(BaseTTS):
         self.args = args

         self.init_multispeaker(config)
+        self.init_multilingual(config)

         self.length_scale = args.length_scale
         self.noise_scale = args.noise_scale
@@ -265,6 +270,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
+            language_emb_dim=self.embedded_language_dim,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -288,16 +294,22 @@ class Vits(BaseTTS):

         if args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                args.hidden_channels + self.embedded_language_dim,
                 192,
                 3,
                 args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim,
+                language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
+                args.hidden_channels + self.embedded_language_dim,
+                256,
+                3,
+                args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
+                language_emb_dim=self.embedded_language_dim,
             )

         self.waveform_decoder = HifiganGenerator(
@@ -356,17 +368,40 @@ class Vits(BaseTTS):
             self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
             self.embedded_speaker_dim = config.d_vector_dim

+    def init_multilingual(self, config: Coqpit, data: List = None):
+        """Initialize multilingual modules of a model.
+
+        Args:
+            config (Coqpit): Model configuration.
+            data (List, optional): Dataset items to infer number of languages. Defaults to None.
+        """
+        if hasattr(config, "model_args"):
+            config = config.model_args
+        # init language manager
+        self.language_manager = LanguageManager(config, data=data)
+
+        # init language embedding layer
+        if config.use_language_embedding:
+            self.embedded_language_dim = config.embedded_language_dim
+            self.emb_l = nn.Embedding(self.language_manager.num_languages, self.embedded_language_dim)
+            torch.nn.init.xavier_uniform_(self.emb_l.weight)
+        else:
+            self.embedded_language_dim = 0
+            self.emb_l = None
+
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g = None, None
+        sid, g, lid = None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
                 sid = sid.unsqueeze_(0)
         if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
             g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
-        return sid, g
+        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+            lid = aux_input["language_ids"]
+        return sid, g, lid

     def get_aux_input(self, aux_input: Dict):
         sid, g = self._set_cond_input(aux_input)
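As a sketch, `_set_cond_input` is a staticmethod, so the new three-way return can be exercised directly (with `Vits` imported from `TTS.tts.models.vits`; the IDs below are made up):

    import torch

    aux_input = {
        "speaker_ids": torch.tensor([0, 3]),
        "d_vectors": None,
        "language_ids": torch.tensor([1, 1]),
    }
    sid, g, lid = Vits._set_cond_input(aux_input)
    # sid=tensor([0, 3]), g=None, lid=tensor([1, 1])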
@@ -378,7 +413,7 @@ class Vits(BaseTTS):
         x_lengths: torch.tensor,
         y: torch.tensor,
         y_lengths: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None},
+        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
     ) -> Dict:
         """Forward pass of the model.

@@ -401,13 +436,19 @@ class Vits(BaseTTS):
             - speaker_ids: :math:`[B]`
         """
         outputs = {}
-        sid, g = self._set_cond_input(aux_input)
-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+        sid, g, lid = self._set_cond_input(aux_input)

         # speaker embedding
         if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

+        # language embedding
+        if self.args.use_language_embedding:
+            lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
+
         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
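The shape of the new conditioning signal, traced with a toy embedding table (illustrative sizes):

    import torch
    import torch.nn as nn

    emb_l = nn.Embedding(3, 4)           # num_languages=3, embedded_language_dim=4
    lid = torch.tensor([2, 0])           # language IDs, [B]
    lang_emb = emb_l(lid).unsqueeze(-1)  # [B, 4] -> [B, 4, 1]
    # this is the tensor handed to the text encoder and duration predictor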
@@ -433,6 +474,7 @@ class Vits(BaseTTS):
                 x_mask,
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
         else:
@@ -441,6 +483,7 @@ class Vits(BaseTTS):
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
         outputs["loss_duration"] = loss_duration
@@ -467,25 +510,30 @@ class Vits(BaseTTS):
         )
         return outputs

-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
         Shapes:
             - x: :math:`[B, T_seq]`
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
         """
-        sid, g = self._set_cond_input(aux_input)
+        sid, g, lid = self._set_cond_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
-
-        if self.num_speakers > 0 and sid is not None:
+        # speaker embedding
+        if self.num_speakers > 0 and sid:
             g = self.emb_g(sid).unsqueeze(-1)

+        # language embedding
+        if self.args.use_language_embedding:
+            lang_emb = self.emb_l(lid).unsqueeze(-1)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+
         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
+            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
         else:
-            logw = self.duration_predictor(x, x_mask, g=g)
+            logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)

         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
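A hypothetical inference call with the extended aux_input; `model` stands for a trained multilingual Vits instance and the IDs are placeholders. Note that, as the hunk shows, `lang_emb` is only assigned when `use_language_embedding` is enabled, so this path assumes multilingual conditioning is on:

    import torch

    outputs = model.inference(
        x,  # token IDs, [B, T_seq]
        aux_input={
            "d_vectors": None,
            "speaker_ids": torch.tensor([0]),
            "language_ids": torch.tensor([1]),
        },
    )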
@@ -537,6 +585,7 @@ class Vits(BaseTTS):
             linear_input = batch["linear_input"]
             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
+            language_ids = batch["language_ids"]
             waveform = batch["waveform"]

             # generator pass
@@ -545,7 +594,7 @@ class Vits(BaseTTS):
                 text_lengths,
                 linear_input.transpose(1, 2),
                 mel_lengths,
-                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
             )

             # cache tensors for the discriminator
@@ -581,6 +630,14 @@ class Vits(BaseTTS):
                 loss_duration=outputs["loss_duration"],
             )

+            # handle the duration loss
+            if self.args.use_sdp:
+                loss_dict["nll_duration"] = outputs["nll_duration"]
+                loss_dict["loss"] += outputs["nll_duration"]
+            else:
+                loss_dict["loss_duration"] = outputs["loss_duration"]
+                loss_dict["loss"] += outputs["loss_duration"]
+
         elif optimizer_idx == 1:
             # discriminator pass
             outputs = {}
@@ -0,0 +1,138 @@
+import os
+import json
+
+import torch
+import fsspec
+import numpy as np
+from typing import Dict, Tuple, List
+from coqpit import Coqpit
+
+from torch.utils.data.sampler import WeightedRandomSampler
+
+
+class LanguageManager:
+    """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
+    in a way that can be queried by language.
+
+    Args:
+        language_id_file_path (str, optional): Path to the metafile that maps language names to ids used by
+        TTS models. Defaults to "".
+
+    Examples:
+        >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
+        >>> language_id_mapper = manager.language_ids
+    """
+
+    num_languages: int = 0
+    language_id_mapping: Dict = {}
+
+    def __init__(
+        self,
+        language_id_file_path: str = "",
+    ):
+        if language_id_file_path:
+            self.set_language_ids_from_file(language_id_file_path)
+
+    @staticmethod
+    def _load_json(json_file_path: str) -> Dict:
+        with fsspec.open(json_file_path, "r") as f:
+            return json.load(f)
+
+    @staticmethod
+    def _save_json(json_file_path: str, data: dict) -> None:
+        with fsspec.open(json_file_path, "w") as f:
+            json.dump(data, f, indent=4)
+
+    @property
+    def num_languages(self) -> int:
+        return len(list(self.language_id_mapping.keys()))
+
+    @property
+    def language_names(self) -> List:
+        return list(self.language_id_mapping.keys())
+
+    @staticmethod
+    def parse_languages_from_data(items: list) -> Tuple[Dict, int]:
+        """Parse language IDs from data samples returned by `load_meta_data()`.
+
+        Args:
+            items (list): Data samples returned by `load_meta_data()`.
+
+        Returns:
+            Tuple[Dict, int]: language IDs and number of languages.
+        """
+        languages = sorted({item[3] for item in items})
+        language_ids = {name: i for i, name in enumerate(languages)}
+        num_languages = len(language_ids)
+        return language_ids, num_languages
+
+    def set_language_ids_from_data(self, items: List) -> None:
+        """Set language IDs from data samples.
+
+        Args:
+            items (List): Data samples returned by `load_meta_data()`.
+        """
+        self.language_id_mapping, _ = self.parse_languages_from_data(items)
+
+    def set_language_ids_from_file(self, file_path: str) -> None:
+        """Load language ids from a json file.
+
+        Args:
+            file_path (str): Path to the target json file.
+        """
+        self.language_id_mapping = self._load_json(file_path)
+        self.num_languages = len(self.language_id_mapping)
+
+    def save_language_ids_to_file(self, file_path: str) -> None:
+        """Save language IDs to a json file.
+
+        Args:
+            file_path (str): Path to the output file.
+        """
+        self._save_json(file_path, self.language_id_mapping)
+
+
+def _set_file_path(path):
+    """Find the language_ids.json under the given path or above it.
+    Intended to band-aid the different paths returned in restored and continued training."""
+    path_restore = os.path.join(os.path.dirname(path), "language_ids.json")
+    path_continue = os.path.join(path, "language_ids.json")
+    fs = fsspec.get_mapper(path).fs
+    if fs.exists(path_restore):
+        return path_restore
+    if fs.exists(path_continue):
+        return path_continue
+    return None
+
+
+def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> LanguageManager:
+    """Initiate a `LanguageManager` instance by the provided config.
+
+    Args:
+        c (Coqpit): Model configuration.
+        restore_path (str): Path to a previous training folder.
+        data (List): Data samples returned by `load_meta_data()`. Defaults to None.
+        out_path (str, optional): Save the generated language IDs to an output path. Defaults to None.
+
+    Returns:
+        LanguageManager: initialized and ready to use instance.
+    """
+    language_manager = LanguageManager()
+    if c.use_language_embedding:
+        if data is not None:
+            language_manager.set_language_ids_from_data(data)
+        if restore_path:
+            language_file = _set_file_path(restore_path)
+            # restoring language manager from a previous run.
+            if language_file:
+                language_manager.set_language_ids_from_file(language_file)
+        if language_manager.num_languages > 0:
+            print(
+                " > Language manager is loaded with {} languages: {}".format(
+                    language_manager.num_languages, ", ".join(language_manager.language_names)
+                )
+            )
+    return language_manager
+
+
+def get_language_weighted_sampler(items: list):
+    language_names = np.array([item[3] for item in items])
+    unique_language_names = np.unique(language_names).tolist()
+    language_ids = [unique_language_names.index(l) for l in language_names]
+    language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
+    weight_language = 1. / language_count
+    dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
+    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
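A usage sketch for the new module, assuming it is importable as the import added to vits.py above suggests (`from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler`); the metadata rows are fabricated and follow the `(text, wav_path, speaker, language)` layout implied by `item[3]`:

    items = [
        ("ola", "a.wav", "spk1", "pt-br"),
        ("hello", "b.wav", "spk2", "en"),
        ("hi", "c.wav", "spk2", "en"),
    ]

    manager = LanguageManager()
    manager.set_language_ids_from_data(items)
    print(manager.language_id_mapping)  # {'en': 0, 'pt-br': 1}

    # rarer languages get proportionally larger weights (1 / count), so
    # 'pt-br' is drawn about as often as 'en' in expectation
    sampler = get_language_weighted_sampler(items)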
@@ -379,11 +379,14 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
         # new speaker manager with speaker IDs file.
         speaker_manager.set_speaker_ids_from_file(c.speakers_file)
-        print(
-            " > Speaker manager is loaded with {} speakers: {}".format(
-                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
-            )
-        )
+
+    if speaker_manager.num_speakers > 0:
+        print(
+            " > Speaker manager is loaded with {} speakers: {}".format(
+                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
+            )
+        )

     # save file if path is defined
     if out_path:
         out_file_path = os.path.join(out_path, "speakers.json")