Refactor VITS multi-speaker initialization

Eren Gölge 2021-10-15 10:20:00 +00:00
parent 0565457faa
commit 073a2d2eb0
4 changed files with 88 additions and 26 deletions

View File

@@ -252,11 +252,6 @@ class Trainer:
else:
self.run_get_model(self.config, get_model)
# TODO: out!
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
# setup criterion
self.criterion = self.get_criterion(self.model)
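For reference, the lines removed from the Trainer follow a duck-typing guard: `init_multispeaker` is only called when the model defines it. A minimal standalone sketch of that pattern (hypothetical model class, not the real Trainer):

class DummySpeakerAwareModel:
    def init_multispeaker(self, config, data=None):
        num_speakers = getattr(config, "num_speakers", 0)
        print(f" > init_multispeaker called with {num_speakers} speakers")

model = DummySpeakerAwareModel()
# the guard the Trainer used before this commit; the model now initializes itself instead
if hasattr(model, "init_multispeaker"):
    model.init_multispeaker(config=None, data=None)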

View File

@@ -218,7 +218,3 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda: [])
# multi-speaker
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_dim: int = 0

View File

@@ -139,3 +139,36 @@ class VitsConfig(BaseTTSConfig):
"Prior to November 22, 1963.",
]
)
# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
# use d-vectors
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = None
def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for them there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers
# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file
if self.speaker_embedding_channels:
self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
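The `__post_init__` above only forwards the top-level multi-speaker fields into `model_args`, since `Vits.init_multispeaker()` reads them from there. A standalone sketch of that forwarding pattern (hypothetical `ModelArgs`/`Config` dataclasses, not the real TTS config classes):

from dataclasses import dataclass, field

@dataclass
class ModelArgs:
    num_speakers: int = 0
    use_speaker_embedding: bool = False

@dataclass
class Config:
    model_args: ModelArgs = field(default_factory=ModelArgs)
    num_speakers: int = 0
    use_speaker_embedding: bool = False

    def __post_init__(self):
        # mirror the top-level fields into model_args, as VitsConfig does above
        if self.num_speakers > 0:
            self.model_args.num_speakers = self.num_speakers
        if self.use_speaker_embedding:
            self.model_args.use_speaker_embedding = True

cfg = Config(num_speakers=4, use_speaker_embedding=True)
print(cfg.model_args)  # ModelArgs(num_speakers=4, use_speaker_embedding=True)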

View File

@@ -1,4 +1,6 @@
import math
import os
import random
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple
@@ -14,7 +16,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -180,6 +182,7 @@ class VitsArgs(Coqpit):
speakers_file: str = None
speaker_embedding_channels: int = 256
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
detach_dp_input: bool = True
@@ -315,27 +318,50 @@ class Vits(BaseTTS):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
if hasattr(config, "model_args"):
config = config.model_args
self.embedded_speaker_dim = 0
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
self.speaker_manager.num_speakers = config.num_speakers
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
# init d-vector usage
self.num_speakers = config.num_speakers
if config.use_speaker_embedding:
self._init_speaker_embedding(config)
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
self._init_d_vector(config)
def _init_speaker_embedding(self, config):
# pylint: disable=attribute-defined-outside-init
if config.speakers_file is not None:
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = config.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
def _init_d_vector(self, config):
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
self.embedded_speaker_dim = config.d_vector_dim
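To make the two initialization paths above concrete, here is a minimal sketch of how each one yields the conditioning vector: a learned `nn.Embedding` lookup for integer speaker ids versus an externally computed d-vector. Shapes and values are illustrative assumptions, not the exact VITS wiring:

import torch
from torch import nn

num_speakers, embedding_channels, d_vector_dim = 4, 256, 512

# path 1: learned speaker embedding, as set up in _init_speaker_embedding()
emb_g = nn.Embedding(num_speakers, embedding_channels)
sid = torch.tensor([2])            # integer speaker id
g_embedding = emb_g(sid)           # shape [1, embedding_channels]

# path 2: external d-vector from a speaker encoder, as set up in _init_d_vector()
g_dvector = torch.randn(1, d_vector_dim)  # stand-in for a precomputed d-vector

print(g_embedding.shape, g_dvector.shape)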
def on_init_start(self, trainer):
"""Save the speaker.json at the beginning of the training. And update the config.json with the
speakers.json file path."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
trainer.config.speakers_file = output_path
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.json` is saved to {output_path}.")
print(f" > `speakers_file` is updated in the config.json.")
@staticmethod
def _set_cond_input(aux_input: Dict):
@@ -349,6 +375,10 @@ class Vits(BaseTTS):
g = aux_input["d_vectors"]
return sid, g
def get_aux_input(self, aux_input: Dict):
sid, g = self._set_cond_input(aux_input)
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
def forward(
self,
x: torch.tensor,
@@ -633,7 +663,15 @@
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None,
}
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
@@ -670,7 +708,7 @@
)
# add the speaker embedding layer
if hasattr(self, "emb_g"):
gen_parameters = chain(gen_parameters, self.emb_g)
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
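The final change matters because `itertools.chain` must be given iterables of parameters; chaining the `emb_g` module itself would hand the optimizer a non-tensor item. A quick sketch with plain torch (not the TTS `get_optimizer` helper):

from itertools import chain

import torch
from torch import nn

text_encoder = nn.Linear(80, 80)
emb_g = nn.Embedding(4, 80)

# chain parameter iterators; passing `emb_g` directly would raise a TypeError in the optimizer
gen_parameters = chain(text_encoder.parameters(), emb_g.parameters())
optimizer = torch.optim.AdamW(gen_parameters, lr=1e-3)
print(sum(p.numel() for p in optimizer.param_groups[0]["params"]))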