Update Speaker Encoder models

This commit is contained in:
Eren Gölge 2021-12-30 12:02:06 +00:00
parent 2033e17c44
commit 638091f41d
4 changed files with 88 additions and 14 deletions

View File

@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
feats = torch.stack(feats)
labels = torch.stack(labels)
return feats.transpose(1, 2), labels
return feats, labels

View File

@ -1,7 +1,9 @@
import numpy as np
import torch
import torchaudio
from torch import nn
from TTS.speaker_encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec
@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module):
class LSTMSpeakerEncoder(nn.Module):
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
def __init__(
self,
input_dim,
proj_dim=256,
lstm_dim=768,
num_lstm_layers=3,
use_lstm_with_projection=True,
use_torch_spec=False,
audio_config=None,
):
super().__init__()
self.use_lstm_with_projection = use_lstm_with_projection
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
layers = []
# choise LSTM layer
if use_lstm_with_projection:
@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module):
else:
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
self._init_layers()
def _init_layers(self):
@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module):
elif "weight" in name:
nn.init.xavier_normal_(param)
def forward(self, x):
# TODO: implement state passing for lstms
def forward(self, x, l2_norm=True):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x.squeeze_(1)
x = self.torch_spec(x)
x = self.instancenorm(x).transpose(1, 2)
d = self.layers(x)
if self.use_lstm_with_projection:
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
else:
d = d[:, -1]
if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
def inference(self, x):
d = self.layers.forward(x)
if self.use_lstm_with_projection:
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
else:
d = torch.nn.functional.normalize(d, p=2, dim=1)
def inference(self, x, l2_norm=True):
d = self.layers.forward(x, l2_norm=l2_norm)
return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):

View File

@ -190,8 +190,19 @@ class ResNetSpeakerEncoder(nn.Module):
return out
def forward(self, x, l2_norm=False):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
x.squeeze_(1)
# if you torch spec compute it otherwise use the mel spec computed by the AP
if self.use_torch_spec:
x = self.torch_spec(x)
@ -230,7 +241,11 @@ class ResNetSpeakerEncoder(nn.Module):
return x
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
def inference(self, x, l2_norm=False):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
@ -254,7 +269,7 @@ class ResNetSpeakerEncoder(nn.Module):
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.forward(frames_batch, l2_norm=True)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)

View File

@ -177,6 +177,8 @@ def setup_speaker_encoder_model(config: "Coqpit"):
config.model_params["proj_dim"],
config.model_params["lstm_dim"],
config.model_params["num_lstm_layers"],
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
elif config.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(