Refactor VITS multi-speaker initialization

This commit is contained in:
Eren Gölge 2021-10-15 10:20:00 +00:00
parent 0565457faa
commit 073a2d2eb0
4 changed files with 88 additions and 26 deletions

View File

@ -252,11 +252,6 @@ class Trainer:
else: else:
self.run_get_model(self.config, get_model) self.run_get_model(self.config, get_model)
# TODO: out!
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
# setup criterion # setup criterion
self.criterion = self.get_criterion(self.model) self.criterion = self.get_criterion(self.model)

View File

@ -218,7 +218,3 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler_params: dict = field(default_factory=lambda: {}) lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing # testing
test_sentences: List[str] = field(default_factory=lambda: []) test_sentences: List[str] = field(default_factory=lambda: [])
# multi-speaker
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_dim: int = 0

View File

@ -139,3 +139,36 @@ class VitsConfig(BaseTTSConfig):
"Prior to November 22, 1963.", "Prior to November 22, 1963.",
] ]
) )
# multi-speaker settings
# -- speaker embedding layer --
# number of speakers; copied to `model_args` by `__post_init__` only when > 0
num_speakers: int = 0
# enable the speaker embedding layer
use_speaker_embedding: bool = False
# path of the speaker ids file; `None` means "not set"
speakers_file: str = None
# channel count of the speaker embedding layer
speaker_embedding_channels: int = 256
# -- d-vectors --
# enable d-vector conditioning
use_d_vector_file: bool = False
# path of the d-vectors file; `None` (not `False`) means "not set", matching `speakers_file`
# above and the `d_vector_file` default declared on `VitsArgs`
d_vector_file: str = None
# d-vector dimension; `None` leaves the value already held by `model_args` untouched
d_vector_dim: int = None
def __post_init__(self):
    """Copy the multi-speaker options of this config onto `self.model_args`.

    `model.init_multispeaker()` looks for these options on the model args, so
    every option that is set here is mirrored there. An option that is unset
    (zero, empty, `None` or `False`) is skipped and the value already stored
    on `model_args` is kept.
    """
    speaker_count = self.num_speakers
    if speaker_count > 0:
        self.model_args.num_speakers = speaker_count

    # --- speaker embedding layer ---
    if self.use_speaker_embedding:
        self.model_args.use_speaker_embedding = True
    ids_file = self.speakers_file
    if ids_file:
        self.model_args.speakers_file = ids_file
    embedding_channels = self.speaker_embedding_channels
    if embedding_channels:
        self.model_args.speaker_embedding_channels = embedding_channels

    # --- d-vectors ---
    if self.use_d_vector_file:
        self.model_args.use_d_vector_file = True
    vector_dim = self.d_vector_dim
    # `None` is a legal default for the dimension, so rule it out before comparing
    if vector_dim is not None and vector_dim > 0:
        self.model_args.d_vector_dim = vector_dim
    vectors_file = self.d_vector_file
    if vectors_file:
        self.model_args.d_vector_file = vectors_file

View File

@ -1,4 +1,6 @@
import math import math
import os
import random
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@ -14,7 +16,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@ -180,6 +182,7 @@ class VitsArgs(Coqpit):
speakers_file: str = None speakers_file: str = None
speaker_embedding_channels: int = 256 speaker_embedding_channels: int = 256
use_d_vector_file: bool = False use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0 d_vector_dim: int = 0
detach_dp_input: bool = True detach_dp_input: bool = True
@ -315,27 +318,50 @@ class Vits(BaseTTS):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model. or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None. data (List, optional): Dataset items to infer number of speakers. Defaults to None.
""" """
self.embedded_speaker_dim = 0
if hasattr(config, "model_args"): if hasattr(config, "model_args"):
config = config.model_args config = config.model_args
self.embedded_speaker_dim = 0
# init speaker manager self.num_speakers = config.num_speakers
self.speaker_manager = get_speaker_manager(config, data=data)
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0: if config.use_speaker_embedding:
self.speaker_manager.num_speakers = config.num_speakers self._init_speaker_embedding(config)
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
# init d-vector usage
if config.use_d_vector_file: if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim self._init_d_vector(config)
def _init_speaker_embedding(self, config):
    """Initialize the speaker manager and the speaker embedding layer.

    Args:
        config (Coqpit): Model arguments. `speakers_file` and
            `speaker_embedding_channels` are read from it.
    """
    # pylint: disable=attribute-defined-outside-init
    # The config field is called `speakers_file`; use the same name for the guard and
    # for the path handed to the speaker manager.
    if config.speakers_file is not None:
        self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)

    # The embedding table is created only when the number of speakers is known.
    if self.num_speakers > 0:
        print(" > initialization of speaker-embedding layers.")
        self.embedded_speaker_dim = config.speaker_embedding_channels
        self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
def _init_d_vector(self, config):
    """Initialize the speaker manager from a d-vector file and set the speaker dimension.

    Args:
        config (Coqpit): Model arguments. `d_vector_file` and `d_vector_dim`
            are read from it.

    Raises:
        ValueError: If the model already has a speaker embedding layer (`emb_g`).
    """
    # pylint: disable=attribute-defined-outside-init
    if not hasattr(self, "emb_g"):
        d_vectors_path = config.d_vector_file
        self.speaker_manager = SpeakerManager(d_vectors_file_path=d_vectors_path)
        self.embedded_speaker_dim = config.d_vector_dim
        return
    # An embedding layer was set up first; d-vectors cannot be initialized on top of it.
    raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
def on_init_start(self, trainer):
    """Save `speakers.json` at the beginning of the training and record its path in the config.

    When the model has a speaker manager, the speaker ids are written to
    `speakers.json` under `trainer.output_path`, `speakers_file` is updated on
    both `trainer.config` and `trainer.config.model_args`, and the config is
    saved again as `config.json` in the same folder. Without a speaker manager
    the call does nothing.

    Args:
        trainer: Trainer running this model; its `output_path` and `config` are used.
    """
    # In this file `speaker_manager` is assigned only inside the multi-speaker init
    # helpers, so the attribute may not exist at all; treat "missing" like `None`.
    speaker_manager = getattr(self, "speaker_manager", None)
    if speaker_manager is None:
        return

    output_path = os.path.join(trainer.output_path, "speakers.json")
    speaker_manager.save_speaker_ids_to_file(output_path)
    # keep the top-level config and the model args pointing at the same file
    trainer.config.speakers_file = output_path
    trainer.config.model_args.speakers_file = output_path
    trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
    print(f" > `speakers.json` is saved to {output_path}.")
    print(" > `speakers_file` is updated in the config.json.")
@staticmethod @staticmethod
def _set_cond_input(aux_input: Dict): def _set_cond_input(aux_input: Dict):
@ -349,6 +375,10 @@ class Vits(BaseTTS):
g = aux_input["d_vectors"] g = aux_input["d_vectors"]
return sid, g return sid, g
def get_aux_input(self, aux_input: Dict):
    """Build the auxiliary-input dictionary from the conditioning values in `aux_input`.

    Args:
        aux_input (Dict): Passed unchanged to `self._set_cond_input()`, which
            returns the `(speaker id, d-vector)` pair.

    Returns:
        Dict: `speaker_id` and `d_vector` hold the pair returned by
            `_set_cond_input()`; `style_wav` is always `None`.
    """
    speaker_id, d_vector = self._set_cond_input(aux_input)
    aux_values = dict(speaker_id=speaker_id, style_wav=None, d_vector=d_vector)
    return aux_values
def forward( def forward(
self, self,
x: torch.tensor, x: torch.tensor,
@ -633,7 +663,15 @@ class Vits(BaseTTS):
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input() aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None,
}
for idx, sen in enumerate(test_sentences): for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis( wav, alignment, _, _ = synthesis(
self, self,
@ -670,7 +708,7 @@ class Vits(BaseTTS):
) )
# add the speaker embedding layer # add the speaker embedding layer
if hasattr(self, "emb_g"): if hasattr(self, "emb_g"):
gen_parameters = chain(gen_parameters, self.emb_g) gen_parameters = chain(gen_parameters, self.emb_g.parameters())
optimizer0 = get_optimizer( optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
) )