mirror of https://github.com/coqui-ai/TTS.git
Use speaker_encoder from speaker manager in Vits
commit d29c3780d1
parent 4d13b887f5
@@ -3,7 +3,7 @@ import torch
 import torchaudio
 from torch import nn

-from TTS.utils.audio import TorchSTFT
+# from TTS.utils.audio import TorchSTFT
 from TTS.utils.io import load_fsspec

@@ -258,7 +258,6 @@ class ResNetSpeakerEncoder(nn.Module):
         if return_mean:
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)

         return embeddings

     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
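Note on the pooling in the hunk above: with return_mean enabled, the encoder averages the per-chunk embeddings of an utterance into a single vector. A self-contained sketch of that step (the shapes and chunk count are illustrative, not taken from the model):

import torch

# Hypothetical input: 5 chunks of one utterance, each mapped to a 512-dim embedding.
embeddings = torch.randn(5, 512)

return_mean = True
if return_mean:
    # Collapse the chunk axis but keep a leading batch axis of size 1,
    # matching the ResNetSpeakerEncoder branch above.
    embeddings = torch.mean(embeddings, dim=0, keepdim=True)

print(embeddings.shape)  # torch.Size([1, 512])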
@@ -406,42 +406,32 @@ class Vits(BaseTTS):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
             self.speaker_manager.init_speaker_encoder(
                 config.speaker_encoder_model_path, config.speaker_encoder_config_path
             )
-            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
-            for param in self.speaker_encoder.parameters():
-                param.requires_grad = False
-
+            self.speaker_manager.speaker_encoder.eval()
             print(" > External Speaker Encoder Loaded !!")

             if (
-                hasattr(self.speaker_encoder, "audio_config")
-                and self.config.audio["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
             ):
                 # TODO: change this with torchaudio Resample
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
-                        self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
+                        self.config.audio["sample_rate"],
+                        self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                     )
                 )
                 # pylint: disable=W0101,W0105
                 """ self.audio_transform = torchaudio.transforms.Resample(
                     orig_freq=self.audio_config["sample_rate"],
-                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                    new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                 )
                 else:
                     self.audio_transform = None
                 """
-        else:
-            # self.audio_transform = None
-            self.speaker_encoder = None

     def _init_speaker_embedding(self, config):
         # pylint: disable=attribute-defined-outside-init
         if config.speakers_file is not None:
             self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)

         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
             self.embedded_speaker_dim = config.speaker_embedding_channels
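The first hunk above replaces the frozen train()-mode copy of the encoder with a single eval() call on speaker_manager.speaker_encoder. In PyTorch, train()/eval() and requires_grad are independent switches, which is the point of the change; a minimal sketch with a stand-in module (not the real speaker encoder):

import torch
from torch import nn

# Stand-in encoder; BatchNorm1d makes the train()/eval() difference observable.
encoder = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))

# Old approach: train() mode with frozen parameters. Gradients are off,
# but batchnorm running stats would still update on every forward pass.
encoder.train()
for param in encoder.parameters():
    param.requires_grad = False

# New approach: eval() also fixes batchnorm/dropout behavior.
encoder.eval()
print(encoder.training)  # False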
@@ -451,7 +441,6 @@ class Vits(BaseTTS):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
         self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
         self.embedded_speaker_dim = config.d_vector_dim

     def init_multilingual(self, config: Coqpit):
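In the d-vector path above, the speaker embedding width is taken directly from config.d_vector_dim. A minimal sketch of that wiring (the config class here is a hypothetical stand-in for the project's Coqpit-based config, and the file path is a placeholder):

from dataclasses import dataclass

@dataclass
class DVectorConfig:
    d_vector_file: str = "speakers.json"  # placeholder path, not a real file
    d_vector_dim: int = 256               # must match the speaker encoder's output size

config = DVectorConfig()
# As in Vits above: the embedding dimension is the d-vector dimension.
embedded_speaker_dim = config.d_vector_dim
print(embedded_speaker_dim)  # 256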
@@ -644,7 +633,7 @@ class Vits(BaseTTS):
             self.args.spec_segment_size * self.config.audio.hop_length,
         )

-        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
             # concatenate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
@@ -653,7 +642,7 @@ class Vits(BaseTTS):
             """if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)"""

-            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

             # split generated and GT speaker embeddings
             gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
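The two hunks above are the speaker consistency loss (SCL) path in the forward pass: GT and generated waveforms are embedded in one batch and then split back apart. A self-contained sketch with a toy encoder in place of speaker_manager.speaker_encoder; the final cosine-similarity objective is an assumption about how gt_spk_emb and syn_spk_emb are consumed downstream, not code from this commit:

import torch
import torch.nn.functional as F

batch = 4
wav_seg = torch.randn(batch, 1, 8192)  # stand-in for the segmented GT waveform
o = torch.randn(batch, 1, 8192)        # stand-in for the generator output

def toy_encoder(wavs, l2_norm=True):
    # Toy replacement for the speaker encoder: random embeddings, optionally L2-normalized.
    embs = torch.randn(wavs.shape[0], 512)
    return F.normalize(embs, p=2, dim=1) if l2_norm else embs

# Concatenate GT and generated audio so both are embedded in a single pass.
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
pred_embs = toy_encoder(wavs_batch, l2_norm=True)

# Split back into GT and synthesized speaker embeddings.
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)

# Assumed SCL objective: maximize cosine similarity between the two halves.
scl_loss = -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
print(float(scl_loss))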
@@ -1024,6 +1013,10 @@ class Vits(BaseTTS):
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        # compat band-aid for the pre-trained models to not use the encoder baked into the model
+        # TODO: consider baking the speaker encoder into the model and calling it from there,
+        # as it is probably easier for model distribution.
+        state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
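The added lines above strip any encoder weights baked into older checkpoints before load_state_dict runs, so the externally managed speaker encoder is used instead. A self-contained sketch of that filtering (toy state dict, hypothetical key names):

import torch

# Toy checkpoint whose model state has a speaker encoder baked in.
state = {
    "model": {
        "text_encoder.weight": torch.zeros(2, 2),
        "speaker_encoder.fc.weight": torch.zeros(2, 2),
        "waveform_decoder.bias": torch.zeros(2),
    }
}

# Same filtering as in load_checkpoint above: drop all speaker_encoder keys.
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
print(sorted(state["model"]))  # ['text_encoder.weight', 'waveform_decoder.bias']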