From 638091f41d289908e1f0a435eb26668c06e3ef05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 30 Dec 2021 12:02:06 +0000 Subject: [PATCH] Update Speaker Encoder models --- TTS/speaker_encoder/dataset.py | 2 +- TTS/speaker_encoder/models/lstm.py | 79 +++++++++++++++++++--- TTS/speaker_encoder/models/resnet.py | 19 +++++- TTS/speaker_encoder/utils/generic_utils.py | 2 + 4 files changed, 88 insertions(+), 14 deletions(-) diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 6b2b0dd4..5b0fee22 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset): feats = torch.stack(feats) labels = torch.stack(labels) - return feats.transpose(1, 2), labels + return feats, labels diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index de5bb007..3c2eafee 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -1,7 +1,9 @@ import numpy as np import torch +import torchaudio from torch import nn +from TTS.speaker_encoder.models.resnet import PreEmphasis from TTS.utils.io import load_fsspec @@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module): class LSTMSpeakerEncoder(nn.Module): - def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): super().__init__() self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + layers = [] # choise LSTM layer if use_lstm_with_projection: @@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module): else: self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + else: + self.torch_spec = None + self._init_layers() def _init_layers(self): @@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module): elif "weight" in name: nn.init.xavier_normal_(param) - def forward(self, x): - # TODO: implement state passing for lstms + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) d = self.layers(x) if self.use_lstm_with_projection: - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) - else: + d = d[:, -1] + if l2_norm: d = torch.nn.functional.normalize(d, p=2, dim=1) return d @torch.no_grad() - def inference(self, x): - d = self.layers.forward(x) - if self.use_lstm_with_projection: - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) - else: - d = torch.nn.functional.normalize(d, p=2, dim=1) + def inference(self, x, l2_norm=True): + d = self.layers.forward(x, l2_norm=l2_norm) return d def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index 7a384ef5..f1f13df1 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -190,8 +190,19 @@ class ResNetSpeakerEncoder(nn.Module): return out def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` + """ with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): + x.squeeze_(1) # if you torch spec compute it otherwise use the mel spec computed by the AP if self.use_torch_spec: x = self.torch_spec(x) @@ -230,7 +241,11 @@ class ResNetSpeakerEncoder(nn.Module): return x @torch.no_grad() - def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + def inference(self, x, l2_norm=False): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): """ Generate embeddings for a batch of utterances x: 1xTxD @@ -254,7 +269,7 @@ class ResNetSpeakerEncoder(nn.Module): frames_batch.append(frames) frames_batch = torch.cat(frames_batch, dim=0) - embeddings = self.forward(frames_batch, l2_norm=True) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) if return_mean: embeddings = torch.mean(embeddings, dim=0, keepdim=True) diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index dab79f3c..b8aa4093 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -177,6 +177,8 @@ def setup_speaker_encoder_model(config: "Coqpit"): config.model_params["proj_dim"], config.model_params["lstm_dim"], config.model_params["num_lstm_layers"], + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, ) elif config.model_params["model_name"].lower() == "resnet": model = ResNetSpeakerEncoder(