diff --git a/TTS/trainer.py b/TTS/trainer.py index afe51e04..aa925972 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -252,11 +252,6 @@ class Trainer: else: self.run_get_model(self.config, get_model) - # TODO: out! - # init multispeaker settings of the model - if hasattr(self.model, "init_multispeaker"): - self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples) - # setup criterion self.criterion = self.get_criterion(self.model) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index e208c16c..60ef7276 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -218,7 +218,3 @@ class BaseTTSConfig(BaseTrainingConfig): lr_scheduler_params: dict = field(default_factory=lambda: {}) # testing test_sentences: List[str] = field(default_factory=lambda: []) - # multi-speaker - use_speaker_embedding: bool = False - use_d_vector_file: bool = False - d_vector_dim: int = 0 diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index 39479231..c9475a6a 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -139,3 +139,36 @@ class VitsConfig(BaseTTSConfig): "Prior to November 22, 1963.", ] ) + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: str = None + d_vector_dim: int = None + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + if self.speaker_embedding_channels: + self.model_args.speaker_embedding_channels = self.speaker_embedding_channels + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 26d4e7fa..724ff342 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,4 +1,6 @@ import math +import os +import random from dataclasses import dataclass, field from itertools import chain from typing import Dict, List, Tuple @@ -14,7 +16,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.speakers import get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment from TTS.utils.trainer_utils import get_optimizer, get_scheduler @@ -180,6 +182,7 @@ class VitsArgs(Coqpit): speakers_file: str = None speaker_embedding_channels: int = 256 use_d_vector_file: bool = False + d_vector_file: str = None d_vector_dim: int = 0 detach_dp_input: bool = True @@ -315,27 +318,50 @@ class Vits(BaseTTS): """Initialize multi-speaker modules of a model. 
A model can be trained either with a speaker embedding layer or with external `d_vectors` computed from a speaker encoder model. - If you need a different behaviour, override this function for your model. - Args: config (Coqpit): Model configuration. data (List, optional): Dataset items to infer number of speakers. Defaults to None. """ + self.embedded_speaker_dim = 0 if hasattr(config, "model_args"): config = config.model_args - self.embedded_speaker_dim = 0 - # init speaker manager - self.speaker_manager = get_speaker_manager(config, data=data) - if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0: - self.speaker_manager.num_speakers = config.num_speakers - self.num_speakers = self.speaker_manager.num_speakers - # init speaker embedding layer - if config.use_speaker_embedding and not config.use_d_vector_file: - self.embedded_speaker_dim = config.speaker_embedding_channels - self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels) - # init d-vector usage + + self.num_speakers = config.num_speakers + + if config.use_speaker_embedding: + self._init_speaker_embedding(config) + if config.use_d_vector_file: - self.embedded_speaker_dim = config.d_vector_dim + self._init_d_vector(config) + + def _init_speaker_embedding(self, config): + # pylint: disable=attribute-defined-outside-init + if config.speakers_file is not None: + self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file) + + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = config.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self, config): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] 
Speaker embedding layer already initialized before d_vector settings.") + self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + self.embedded_speaker_dim = config.d_vector_dim + + def on_init_start(self, trainer): + """Save the speaker.json at the beginning of the training. And update the config.json with the + speakers.json file path.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.json") + self.speaker_manager.save_speaker_ids_to_file(output_path) + trainer.config.speakers_file = output_path + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.json` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") @staticmethod def _set_cond_input(aux_input: Dict): @@ -349,6 +375,10 @@ class Vits(BaseTTS): g = aux_input["d_vectors"] return sid, g + def get_aux_input(self, aux_input: Dict): + sid, g = self._set_cond_input(aux_input) + return {"speaker_id": sid, "style_wav": None, "d_vector": g} + def forward( self, x: torch.tensor, @@ -633,7 +663,15 @@ class Vits(BaseTTS): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - aux_inputs = self.get_aux_input() + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1), + "d_vector": None + if not self.config.use_d_vector_file + else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1), + "style_wav": None, + } for idx, sen in enumerate(test_sentences): wav, alignment, _, _ = synthesis( self, @@ -670,7 +708,7 @@ class Vits(BaseTTS): ) # add the speaker embedding layer if hasattr(self, "emb_g"): - gen_parameters = chain(gen_parameters, self.emb_g) + gen_parameters = chain(gen_parameters, self.emb_g.parameters()) optimizer0 = get_optimizer( self.config.optimizer, 
self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters )