diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index 92d34494..7a384ef5 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -3,7 +3,7 @@
 import torch
 import torchaudio
 from torch import nn
 
-from TTS.utils.audio import TorchSTFT
+# from TTS.utils.audio import TorchSTFT
 from TTS.utils.io import load_fsspec
 
@@ -258,7 +258,6 @@ class ResNetSpeakerEncoder(nn.Module):
 
         if return_mean:
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)
-
         return embeddings
 
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 7f83f452..ddf6800f 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -406,42 +406,32 @@ class Vits(BaseTTS):
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
-            self.speaker_manager.init_speaker_encoder(
-                config.speaker_encoder_model_path, config.speaker_encoder_config_path
-            )
-            self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
-            for param in self.speaker_encoder.parameters():
-                param.requires_grad = False
+            self.speaker_manager.speaker_encoder.eval()
             print(" > External Speaker Encoder Loaded !!")
 
             if (
-                hasattr(self.speaker_encoder, "audio_config")
-                and self.config.audio["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
+                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
             ):
                 # TODO: change this with torchaudio Resample
                 raise RuntimeError(
                     " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
-                        self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
+                        self.config.audio["sample_rate"],
+                        self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                     )
                 )
                 # pylint: disable=W0101,W0105
                 """ self.audio_transform = torchaudio.transforms.Resample(
                         orig_freq=self.audio_config["sample_rate"],
-                        new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                        new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                     )
                 else:
                     self.audio_transform = None
                 """
-        else:
-            # self.audio_transform = None
-            self.speaker_encoder = None
 
     def _init_speaker_embedding(self, config):
         # pylint: disable=attribute-defined-outside-init
-        if config.speakers_file is not None:
-            self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
-
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
             self.embedded_speaker_dim = config.speaker_embedding_channels
@@ -451,7 +441,6 @@ class Vits(BaseTTS):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
-        self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
         self.embedded_speaker_dim = config.d_vector_dim
 
     def init_multilingual(self, config: Coqpit):
@@ -644,7 +633,7 @@ class Vits(BaseTTS):
            self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
-        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
 
@@ -653,7 +642,7 @@ class Vits(BaseTTS):
             """if self.audio_transform is not None:
                 wavs_batch = self.audio_transform(wavs_batch)"""
 
-            pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
+            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
 
             # split generated and GT speaker embeddings
             gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
@@ -1024,6 +1013,10 @@ class Vits(BaseTTS):
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        # compat band-aid for the pre-trained models to not use the encoder baked into the model
+        # TODO: consider baking the speaker encoder into the model and call it from there.
+        # as it is probably easier for model distribution.
+        state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
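Not part of the patch, for reviewers: a minimal sketch of what the new compat filter in Vits.load_checkpoint does, assuming an older checkpoint that still carries the baked-in speaker encoder weights (the checkpoint path below is hypothetical).

# Minimal sketch of the compat filter added above: any "speaker_encoder" keys
# left over in an older checkpoint are dropped before load_state_dict, so
# pre-trained models saved with the encoder baked in still load on the
# refactored Vits, which no longer registers a speaker_encoder submodule.
import torch

checkpoint_path = "older_vits_checkpoint.pth"  # hypothetical path
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))

# e.g. keys like "speaker_encoder.layer1.0.conv1.weight" are removed here
filtered = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
print(f"dropped {len(state['model']) - len(filtered)} speaker-encoder tensors")
# model.load_state_dict(filtered) then matches the new module layout exactly.

With this filter in place, the external speaker encoder lives only on the SpeakerManager (kept in eval mode with gradients unused), and the SCL forward path reads it through self.speaker_manager.speaker_encoder instead of a model-owned copy.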