mirror of https://github.com/coqui-ai/TTS.git

Refactor VITS multi-speaker initialization

This commit is contained in:
parent 0565457faa
commit 073a2d2eb0
@@ -252,11 +252,6 @@ class Trainer:
         else:
             self.run_get_model(self.config, get_model)
 
-        # TODO: out!
-        # init multispeaker settings of the model
-        if hasattr(self.model, "init_multispeaker"):
-            self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
-
         # setup criterion
         self.criterion = self.get_criterion(self.model)
 
@@ -218,7 +218,3 @@ class BaseTTSConfig(BaseTrainingConfig):
     lr_scheduler_params: dict = field(default_factory=lambda: {})
     # testing
     test_sentences: List[str] = field(default_factory=lambda: [])
-    # multi-speaker
-    use_speaker_embedding: bool = False
-    use_d_vector_file: bool = False
-    d_vector_dim: int = 0
@@ -139,3 +139,36 @@ class VitsConfig(BaseTTSConfig):
             "Prior to November 22, 1963.",
         ]
     )
+
+    # multi-speaker settings
+    # use speaker embedding layer
+    num_speakers: int = 0
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    speaker_embedding_channels: int = 256
+
+    # use d-vectors
+    use_d_vector_file: bool = False
+    d_vector_file: str = None
+    d_vector_dim: int = None
+
+    def __post_init__(self):
+        # Pass the multi-speaker parameters to the model args as `model.init_multispeaker()` looks for them there.
+        if self.num_speakers > 0:
+            self.model_args.num_speakers = self.num_speakers
+
+            # speaker embedding settings
+            if self.use_speaker_embedding:
+                self.model_args.use_speaker_embedding = True
+            if self.speakers_file:
+                self.model_args.speakers_file = self.speakers_file
+            if self.speaker_embedding_channels:
+                self.model_args.speaker_embedding_channels = self.speaker_embedding_channels
+
+            # d-vector settings
+            if self.use_d_vector_file:
+                self.model_args.use_d_vector_file = True
+            if self.d_vector_dim is not None and self.d_vector_dim > 0:
+                self.model_args.d_vector_dim = self.d_vector_dim
+            if self.d_vector_file:
+                self.model_args.d_vector_file = self.d_vector_file
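The `__post_init__` hook above is what forwards the top-level config fields into `model_args`. A minimal usage sketch, assuming the TTS package at this commit is installed and that `VitsConfig` is importable from `TTS.tts.configs.vits_config` (import path assumed):

from TTS.tts.configs.vits_config import VitsConfig  # import path assumed

# Setting the multi-speaker fields on the top-level config...
config = VitsConfig(
    num_speakers=4,
    use_speaker_embedding=True,
    speaker_embedding_channels=256,
)

# ...is mirrored into `model_args` by `__post_init__`, which is where
# `Vits.init_multispeaker()` reads them from.
assert config.model_args.num_speakers == 4
assert config.model_args.use_speaker_embedding is True
assert config.model_args.speaker_embedding_channels == 256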
@@ -1,4 +1,6 @@
 import math
+import os
+import random
 from dataclasses import dataclass, field
 from itertools import chain
 from typing import Dict, List, Tuple
@@ -14,7 +16,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
-from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -180,6 +182,7 @@ class VitsArgs(Coqpit):
     speakers_file: str = None
     speaker_embedding_channels: int = 256
     use_d_vector_file: bool = False
+    d_vector_file: str = None
     d_vector_dim: int = 0
     detach_dp_input: bool = True
 
@@ -315,27 +318,50 @@ class Vits(BaseTTS):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
 
-        If you need a different behaviour, override this function for your model.
-
         Args:
             config (Coqpit): Model configuration.
             data (List, optional): Dataset items to infer number of speakers. Defaults to None.
         """
+        self.embedded_speaker_dim = 0
         if hasattr(config, "model_args"):
             config = config.model_args
-        self.embedded_speaker_dim = 0
-        # init speaker manager
-        self.speaker_manager = get_speaker_manager(config, data=data)
-        if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
-            self.speaker_manager.num_speakers = config.num_speakers
-        self.num_speakers = self.speaker_manager.num_speakers
-        # init speaker embedding layer
-        if config.use_speaker_embedding and not config.use_d_vector_file:
-            self.embedded_speaker_dim = config.speaker_embedding_channels
-            self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
-        # init d-vector usage
+        self.num_speakers = config.num_speakers
+        if config.use_speaker_embedding:
+            self._init_speaker_embedding(config)
         if config.use_d_vector_file:
-            self.embedded_speaker_dim = config.d_vector_dim
+            self._init_d_vector(config)
+
+    def _init_speaker_embedding(self, config):
+        # pylint: disable=attribute-defined-outside-init
+        if config.speakers_file is not None:
+            self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
+
+        if self.num_speakers > 0:
+            print(" > initialization of speaker-embedding layers.")
+            self.embedded_speaker_dim = config.speaker_embedding_channels
+            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
+
+    def _init_d_vector(self, config):
+        # pylint: disable=attribute-defined-outside-init
+        if hasattr(self, "emb_g"):
+            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
+        self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
+        self.embedded_speaker_dim = config.d_vector_dim
+
+    def on_init_start(self, trainer):
+        """Save `speakers.json` at the beginning of the training and update the config.json with the
+        `speakers.json` file path."""
+        if self.speaker_manager is not None:
+            output_path = os.path.join(trainer.output_path, "speakers.json")
+            self.speaker_manager.save_speaker_ids_to_file(output_path)
+            trainer.config.speakers_file = output_path
+            trainer.config.model_args.speakers_file = output_path
+            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
+            print(f" > `speakers.json` is saved to {output_path}.")
+            print(" > `speakers_file` is updated in the config.json.")
 
     @staticmethod
     def _set_cond_input(aux_input: Dict):
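For orientation, here is a standalone PyTorch sketch (not TTS code) of the two conditioning paths the new helpers set up: a learned `nn.Embedding` lookup when `use_speaker_embedding` is on, versus externally computed d-vectors whose dimensionality is simply recorded when `use_d_vector_file` is on. The sizes are illustrative assumptions:

import torch
from torch import nn

num_speakers = 4
speaker_embedding_channels = 256
d_vector_dim = 512

# Path 1: learned speaker embedding, as built by `_init_speaker_embedding()`.
# Integer speaker IDs index a trainable lookup table.
emb_g = nn.Embedding(num_speakers, speaker_embedding_channels)
speaker_id = torch.tensor([2])
g_embedding = emb_g(speaker_id)            # shape: (1, 256)

# Path 2: external d-vectors, as handled by `_init_d_vector()`.
# The vectors come precomputed from a speaker encoder, so the model only
# records their size (embedded_speaker_dim = d_vector_dim).
g_d_vector = torch.randn(1, d_vector_dim)  # stands in for a loaded d-vector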
@@ -349,6 +375,10 @@ class Vits(BaseTTS):
             g = aux_input["d_vectors"]
         return sid, g
 
+    def get_aux_input(self, aux_input: Dict):
+        sid, g = self._set_cond_input(aux_input)
+        return {"speaker_id": sid, "style_wav": None, "d_vector": g}
+
     def forward(
         self,
         x: torch.tensor,
@@ -633,7 +663,15 @@ class Vits(BaseTTS):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self.get_aux_input()
+        aux_inputs = {
+            "speaker_id": None
+            if not self.config.use_speaker_embedding
+            else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
+            "d_vector": None
+            if not self.config.use_d_vector_file
+            else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1),
+            "style_wav": None,
+        }
         for idx, sen in enumerate(test_sentences):
             wav, alignment, _, _ = synthesis(
                 self,
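Note that `random.sample(..., 1)` returns a one-element list rather than a bare value; a tiny standalone illustration with an assumed toy `speaker_ids` mapping:

import random

speaker_ids = {"speaker_a": 0, "speaker_b": 1, "speaker_c": 2}  # assumed toy mapping
picked = random.sample(sorted(speaker_ids.values()), 1)
print(picked)  # e.g. [1] -- a list of length 1, not a scalar ID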
@@ -670,7 +708,7 @@ class Vits(BaseTTS):
         )
         # add the speaker embedding layer
         if hasattr(self, "emb_g"):
-            gen_parameters = chain(gen_parameters, self.emb_g)
+            gen_parameters = chain(gen_parameters, self.emb_g.parameters())
         optimizer0 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
         )
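The `.parameters()` fix matters because the chained iterable is fed straight to the optimizer, which expects parameter tensors, not an `nn.Module`. A standalone sketch (optimizer class and learning rate are placeholders, not the values from the config):

from itertools import chain

import torch
from torch import nn

emb_g = nn.Embedding(4, 256)
other_params = [nn.Parameter(torch.zeros(8))]

# Chaining the module's parameters (not the module itself) gives the optimizer
# a flat iterable of trainable tensors that includes the speaker embedding.
gen_parameters = chain(other_params, emb_g.parameters())
optimizer = torch.optim.AdamW(gen_parameters, lr=2e-4)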