Use speaker_encoder from speaker manager in Vits

Eren Gölge 2021-12-16 14:56:34 +00:00
parent 4d13b887f5
commit d29c3780d1
2 changed files with 13 additions and 21 deletions
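In short, Vits stops keeping its own copy of the speaker encoder (previously grabbed from the manager, switched to train mode, and frozen parameter by parameter) and instead calls the one held by its SpeakerManager directly, putting it in eval mode. A minimal sketch of the access pattern this moves to, with hypothetical stand-in classes in place of the real SpeakerManager and ResNet encoder:

```python
import torch
from torch import nn


class DummyEncoder(nn.Module):
    """Hypothetical stand-in for the ResNet speaker encoder."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 256)

    def forward(self, x, l2_norm=True):
        emb = self.proj(x)
        return nn.functional.normalize(emb, p=2, dim=1) if l2_norm else emb


class DummyManager:
    """Hypothetical stand-in for SpeakerManager holding a speaker_encoder."""

    def __init__(self):
        self.speaker_encoder = DummyEncoder()


manager = DummyManager()

# old pattern (removed in this commit): keep a second reference on the model and freeze it
# self.speaker_encoder = manager.speaker_encoder.train()
# for param in self.speaker_encoder.parameters():
#     param.requires_grad = False

# new pattern (added in this commit): use the manager's encoder directly, in eval mode
manager.speaker_encoder.eval()
embeddings = manager.speaker_encoder(torch.randn(4, 80), l2_norm=True)
print(embeddings.shape)  # torch.Size([4, 256])
```

The encoder stays a single shared object owned by the manager, so checkpoints of the TTS model no longer need to carry its weights (see the last hunk below).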


@@ -3,7 +3,7 @@ import torch
 import torchaudio
 from torch import nn
-from TTS.utils.audio import TorchSTFT
+# from TTS.utils.audio import TorchSTFT
 from TTS.utils.io import load_fsspec
@@ -258,7 +258,6 @@ class ResNetSpeakerEncoder(nn.Module):
         if return_mean:
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)
         return embeddings
-
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):

@@ -406,42 +406,32 @@ class Vits(BaseTTS):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
             self.speaker_manager.init_speaker_encoder(
                 config.speaker_encoder_model_path, config.speaker_encoder_config_path
             )
-            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
-            for param in self.speaker_encoder.parameters():
-                param.requires_grad = False
+            self.speaker_manager.speaker_encoder.eval()
             print(" > External Speaker Encoder Loaded !!")

             if (
-                hasattr(self.speaker_encoder, "audio_config")
-                and self.config.audio["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
             ):
                 # TODO: change this with torchaudio Resample
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
-                        self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
+                        self.config.audio["sample_rate"],
+                        self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                     )
                 )
                 # pylint: disable=W0101,W0105
                 """ self.audio_transform = torchaudio.transforms.Resample(
                     orig_freq=self.audio_config["sample_rate"],
-                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                    new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                 )
-            else:
-                self.audio_transform = None
                 """
-        else:
-            # self.audio_transform = None
-            self.speaker_encoder = None

     def _init_speaker_embedding(self, config):
         # pylint: disable=attribute-defined-outside-init
-        if config.speakers_file is not None:
-            self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
             self.embedded_speaker_dim = config.speaker_embedding_channels
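The TODO in this hunk points at torchaudio.transforms.Resample as the eventual replacement for the hard sample-rate check, and the commented-out block already sketches it. A hedged, self-contained version of that idea (the 22050/16000 rates are only illustrative):

```python
import torch
import torchaudio

# illustrative rates only; the real values come from config.audio and the
# speaker encoder's audio_config in the hunk above
tts_sample_rate = 22050
encoder_sample_rate = 16000

audio_transform = torchaudio.transforms.Resample(
    orig_freq=tts_sample_rate,
    new_freq=encoder_sample_rate,
)

wavs_batch = torch.randn(4, tts_sample_rate)  # one second of fake audio per item
resampled = audio_transform(wavs_batch)
print(resampled.shape)  # torch.Size([4, 16000])
```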
@@ -451,7 +441,6 @@ class Vits(BaseTTS):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
-        self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
         self.embedded_speaker_dim = config.d_vector_dim

     def init_multilingual(self, config: Coqpit):
@@ -644,7 +633,7 @@ class Vits(BaseTTS):
             self.args.spec_segment_size * self.config.audio.hop_length,
         )

-        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
@@ -653,7 +642,7 @@ class Vits(BaseTTS):
             """if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)"""

-            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

             # split generated and GT speaker embeddings
             gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
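This hunk only extracts the embeddings; the consistency term itself is computed in the loss module, which is not part of this diff. As a hedged illustration of what such a speaker consistency loss typically looks like (a cosine-similarity penalty between ground-truth and synthesized embeddings; the exact formulation used by TTS may differ, e.g. a plain negative cosine similarity):

```python
import torch
import torch.nn.functional as F


def speaker_consistency_loss(gt_spk_emb: torch.Tensor, syn_spk_emb: torch.Tensor) -> torch.Tensor:
    # push synthesized speech toward the speaker identity of the reference:
    # cosine similarity is 1 for identical directions, so 1 - cos is minimized at 0
    return (1.0 - F.cosine_similarity(gt_spk_emb, syn_spk_emb, dim=-1)).mean()


# usage with embeddings shaped like the chunked output in the hunk above
pred_embs = torch.randn(8, 256)  # embeddings for GT and generated waveforms stacked
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
print(speaker_consistency_loss(gt_spk_emb, syn_spk_emb))
```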
@@ -1024,6 +1013,10 @@ class Vits(BaseTTS):
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        # compat band-aid so pre-trained models do not use the encoder baked into the model
+        # TODO: consider baking the speaker encoder into the model and calling it from there,
+        # as that is probably easier for model distribution.
+        state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
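The band-aid above drops any speaker-encoder weights that older pre-trained checkpoints baked into the model state, so they still load against the new layout where the encoder lives only in the SpeakerManager. A self-contained sketch of that filtering with a made-up state dict:

```python
import torch
from torch import nn

model = nn.Linear(4, 4)

# pretend checkpoint: the model's own weights plus a stale key from a baked-in speaker encoder
state = {"model": dict(model.state_dict())}
state["model"]["speaker_encoder.proj.weight"] = torch.zeros(2, 2)  # stale, should be dropped

# keep only keys that do not belong to the speaker encoder, as in the hunk above
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}

model.load_state_dict(state["model"])  # loads cleanly without the stale key
print(sorted(state["model"]))  # ['bias', 'weight']
```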