mirror of https://github.com/coqui-ai/TTS.git
Implement training support with d_vecs in the VITS model
This commit is contained in:
parent
6a7db67a91
commit
d91c595c5a
|
@ -8,6 +8,7 @@ import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||||
|
@ -138,6 +139,9 @@ class VitsArgs(Coqpit):
|
||||||
use_d_vector_file (bool):
|
use_d_vector_file (bool):
|
||||||
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
||||||
|
|
||||||
|
d_vector_file (str):
|
||||||
|
Path to the file containing pre-computed speaker embeddings (d-vectors). Defaults to None.
|
||||||
|
|
||||||
d_vector_dim (int):
|
d_vector_dim (int):
|
||||||
Number of d-vector channels. Defaults to 0.
|
Number of d-vector channels. Defaults to 0.
|
||||||
|
|
||||||
|
@ -179,6 +183,7 @@ class VitsArgs(Coqpit):
|
||||||
use_speaker_embedding: bool = False
|
use_speaker_embedding: bool = False
|
||||||
num_speakers: int = 0
|
num_speakers: int = 0
|
||||||
speakers_file: str = None
|
speakers_file: str = None
|
||||||
|
d_vector_file: str = None
|
||||||
speaker_embedding_channels: int = 256
|
speaker_embedding_channels: int = 256
|
||||||
use_d_vector_file: bool = False
|
use_d_vector_file: bool = False
|
||||||
d_vector_file: str = None
|
d_vector_file: str = None
|
||||||
|
@ -360,7 +365,7 @@ class Vits(BaseTTS):
|
||||||
if sid.ndim == 0:
|
if sid.ndim == 0:
|
||||||
sid = sid.unsqueeze_(0)
|
sid = sid.unsqueeze_(0)
|
||||||
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
||||||
g = aux_input["d_vectors"]
|
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
|
||||||
return sid, g
|
return sid, g
|
||||||
|
|
||||||
def get_aux_input(self, aux_input: Dict):
|
def get_aux_input(self, aux_input: Dict):
|
||||||
|
@ -400,7 +405,7 @@ class Vits(BaseTTS):
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
|
||||||
|
|
||||||
# speaker embedding
|
# speaker embedding
|
||||||
if self.num_speakers > 1 and sid is not None:
|
if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
|
||||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# posterior encoder
|
# posterior encoder
|
||||||
|
|
|
@ -154,15 +154,21 @@ class SpeakerManager:
|
||||||
"""
|
"""
|
||||||
self._save_json(file_path, self.d_vectors)
|
self._save_json(file_path, self.d_vectors)
|
||||||
|
|
||||||
def set_d_vectors_from_file(self, file_path: str) -> None:
|
def set_d_vectors_from_file(self, file_path: str, data: List = None) -> None:
|
||||||
"""Load d_vectors from a json file.
|
"""Load d_vectors from a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path (str): Path to the target json file.
|
file_path (str): Path to the target json file.
|
||||||
"""
|
"""
|
||||||
self.d_vectors = self._load_json(file_path)
|
self.d_vectors = self._load_json(file_path)
|
||||||
|
|
||||||
|
# Load speakers from data, because training may use only a subset of the speakers present in d_vector_file.
|
||||||
|
if data is not None:
|
||||||
|
self.speaker_ids, _ = self.parse_speakers_from_data(data)
|
||||||
|
else:
|
||||||
speakers = sorted({x["name"] for x in self.d_vectors.values()})
|
speakers = sorted({x["name"] for x in self.d_vectors.values()})
|
||||||
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
|
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
|
||||||
|
|
||||||
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
|
self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))
|
||||||
|
|
||||||
def get_d_vector_by_clip(self, clip_idx: str) -> List:
|
def get_d_vector_by_clip(self, clip_idx: str) -> List:
|
||||||
|
@ -357,7 +363,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
||||||
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
|
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
|
||||||
)
|
)
|
||||||
speaker_manager.load_d_vectors_file(c.d_vector_file)
|
speaker_manager.load_d_vectors_file(c.d_vector_file)
|
||||||
speaker_manager.set_d_vectors_from_file(speakers_file)
|
speaker_manager.set_d_vectors_from_file(speakers_file, data=data)
|
||||||
elif not c.use_d_vector_file: # restor speaker manager with speaker ID file.
|
elif not c.use_d_vector_file: # restor speaker manager with speaker ID file.
|
||||||
speaker_ids_from_data = speaker_manager.speaker_ids
|
speaker_ids_from_data = speaker_manager.speaker_ids
|
||||||
speaker_manager.set_speaker_ids_from_file(speakers_file)
|
speaker_manager.set_speaker_ids_from_file(speakers_file)
|
||||||
|
@ -366,7 +372,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
||||||
), " [!] You cannot introduce new speakers to a pre-trained model."
|
), " [!] You cannot introduce new speakers to a pre-trained model."
|
||||||
elif c.use_d_vector_file and c.d_vector_file:
|
elif c.use_d_vector_file and c.d_vector_file:
|
||||||
# new speaker manager with external speaker embeddings.
|
# new speaker manager with external speaker embeddings.
|
||||||
speaker_manager.set_d_vectors_from_file(c.d_vector_file)
|
speaker_manager.set_d_vectors_from_file(c.d_vector_file, data=data)
|
||||||
elif c.use_d_vector_file and not c.d_vector_file:
|
elif c.use_d_vector_file and not c.d_vector_file:
|
||||||
raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
|
raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
|
||||||
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
|
elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
|
||||||
|
|
Loading…
Reference in New Issue