mirror of https://github.com/coqui-ai/TTS.git
Refactor VITS multi-speaker initialization
This commit is contained in:
parent
0565457faa
commit
073a2d2eb0
|
@ -252,11 +252,6 @@ class Trainer:
|
|||
else:
|
||||
self.run_get_model(self.config, get_model)
|
||||
|
||||
# TODO: out!
|
||||
# init multispeaker settings of the model
|
||||
if hasattr(self.model, "init_multispeaker"):
|
||||
self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
|
||||
|
||||
# setup criterion
|
||||
self.criterion = self.get_criterion(self.model)
|
||||
|
||||
|
|
|
@ -218,7 +218,3 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
lr_scheduler_params: dict = field(default_factory=lambda: {})
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda: [])
|
||||
# multi-speaker
|
||||
use_speaker_embedding: bool = False
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_dim: int = 0
|
||||
|
|
|
@ -139,3 +139,36 @@ class VitsConfig(BaseTTSConfig):
|
|||
"Prior to November 22, 1963.",
|
||||
]
|
||||
)
|
||||
|
||||
# multi-speaker settings
|
||||
# use speaker embedding layer
|
||||
num_speakers: int = 0
|
||||
use_speaker_embedding: bool = False
|
||||
speakers_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
|
||||
# use d-vectors
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_file: str = False
|
||||
d_vector_dim: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
|
||||
if self.num_speakers > 0:
|
||||
self.model_args.num_speakers = self.num_speakers
|
||||
|
||||
# speaker embedding settings
|
||||
if self.use_speaker_embedding:
|
||||
self.model_args.use_speaker_embedding = True
|
||||
if self.speakers_file:
|
||||
self.model_args.speakers_file = self.speakers_file
|
||||
if self.speaker_embedding_channels:
|
||||
self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
|
||||
|
||||
# d-vector settings
|
||||
if self.use_d_vector_file:
|
||||
self.model_args.use_d_vector_file = True
|
||||
if self.d_vector_dim is not None and self.d_vector_dim > 0:
|
||||
self.model_args.d_vector_dim = self.d_vector_dim
|
||||
if self.d_vector_file:
|
||||
self.model_args.d_vector_file = self.d_vector_file
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Tuple
|
||||
|
@ -14,7 +16,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
|||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
|
@ -180,6 +182,7 @@ class VitsArgs(Coqpit):
|
|||
speakers_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_file: str = None
|
||||
d_vector_dim: int = 0
|
||||
detach_dp_input: bool = True
|
||||
|
||||
|
@ -315,27 +318,50 @@ class Vits(BaseTTS):
|
|||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
if hasattr(config, "model_args"):
|
||||
config = config.model_args
|
||||
self.embedded_speaker_dim = 0
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
|
||||
self.speaker_manager.num_speakers = config.num_speakers
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = config.speaker_embedding_channels
|
||||
self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
|
||||
# init d-vector usage
|
||||
|
||||
self.num_speakers = config.num_speakers
|
||||
|
||||
if config.use_speaker_embedding:
|
||||
self._init_speaker_embedding(config)
|
||||
|
||||
if config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
self._init_d_vector(config)
|
||||
|
||||
def _init_speaker_embedding(self, config):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if config.speakers_file is not None:
|
||||
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file_path)
|
||||
|
||||
if self.num_speakers > 0:
|
||||
print(" > initialization of speaker-embedding layers.")
|
||||
self.embedded_speaker_dim = config.speaker_embedding_channels
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
|
||||
def _init_d_vector(self, config):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if hasattr(self, "emb_g"):
|
||||
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
||||
self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
|
||||
def on_init_start(self, trainer):
|
||||
"""Save the speaker.json at the beginning of the training. And update the config.json with the
|
||||
speakers.json file path."""
|
||||
if self.speaker_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "speakers.json")
|
||||
self.speaker_manager.save_speaker_ids_to_file(output_path)
|
||||
trainer.config.speakers_file = output_path
|
||||
trainer.config.model_args.speakers_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `speakers.json` is saved to {output_path}.")
|
||||
print(f" > `speakers_file` is updated in the config.json.")
|
||||
|
||||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
|
@ -349,6 +375,10 @@ class Vits(BaseTTS):
|
|||
g = aux_input["d_vectors"]
|
||||
return sid, g
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.tensor,
|
||||
|
@ -633,7 +663,15 @@ class Vits(BaseTTS):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||
"d_vector": None
|
||||
if not self.config.use_d_vector_file
|
||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
||||
"style_wav": None,
|
||||
}
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, _, _ = synthesis(
|
||||
self,
|
||||
|
@ -670,7 +708,7 @@ class Vits(BaseTTS):
|
|||
)
|
||||
# add the speaker embedding layer
|
||||
if hasattr(self, "emb_g"):
|
||||
gen_parameters = chain(gen_parameters, self.emb_g)
|
||||
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
|
||||
optimizer0 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue