mirror of https://github.com/coqui-ai/TTS.git
Update Speaker Encoder models
This commit is contained in:
parent 2033e17c44
commit 638091f41d
TTS/speaker_encoder/dataset.py
@@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
         feats = torch.stack(feats)
         labels = torch.stack(labels)

-        return feats.transpose(1, 2), labels
+        return feats, labels
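Note that the transpose removed here is not lost: it moves into the encoders, whose forward passes now run `instancenorm(x).transpose(1, 2)` on the batch (see the lstm.py hunks below). A minimal sketch of the new data flow, assuming each item in `feats` is a `(num_mels, T)` mel-spectrogram tensor (the per-item layout is an assumption, not shown in this diff):

    import torch

    # hypothetical batch: 4 utterances, 64 mel bands, 250 frames each
    feats = [torch.randn(64, 250) for _ in range(4)]
    batch = torch.stack(feats)              # (4, 64, 250), returned as-is by the new collate
    x = torch.nn.InstanceNorm1d(64)(batch)  # per-channel normalization, now inside the model
    x = x.transpose(1, 2)                   # (4, 250, 64), time-major for the LSTM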
TTS/speaker_encoder/models/lstm.py
@@ -1,7 +1,9 @@
 import numpy as np
 import torch
+import torchaudio
 from torch import nn

+from TTS.speaker_encoder.models.resnet import PreEmphasis
 from TTS.utils.io import load_fsspec

@@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module):


 class LSTMSpeakerEncoder(nn.Module):
-    def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
+    def __init__(
+        self,
+        input_dim,
+        proj_dim=256,
+        lstm_dim=768,
+        num_lstm_layers=3,
+        use_lstm_with_projection=True,
+        use_torch_spec=False,
+        audio_config=None,
+    ):
         super().__init__()
         self.use_lstm_with_projection = use_lstm_with_projection
+        self.use_torch_spec = use_torch_spec
+        self.audio_config = audio_config

         layers = []
         # choose LSTM layer
         if use_lstm_with_projection:
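For reference, a minimal construction sketch with the new arguments; the `audio_config` values below are illustrative assumptions, not taken from this commit, and `input_dim` must match `num_mels` when `use_torch_spec` is enabled:

    from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder

    audio_config = {
        "preemphasis": 0.97,   # assumed value
        "sample_rate": 16000,  # assumed value
        "fft_size": 512,       # assumed value
        "win_length": 400,     # assumed value
        "hop_length": 160,     # assumed value
        "num_mels": 64,        # assumed value
    }

    model = LSTMSpeakerEncoder(
        input_dim=64,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_torch_spec=True,       # compute mel features inside the model
        audio_config=audio_config,
    )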
@@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module):
         else:
             self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)

+        self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+        if self.use_torch_spec:
+            self.torch_spec = torch.nn.Sequential(
+                PreEmphasis(audio_config["preemphasis"]),
+                # TorchSTFT(
+                #     n_fft=audio_config["fft_size"],
+                #     hop_length=audio_config["hop_length"],
+                #     win_length=audio_config["win_length"],
+                #     sample_rate=audio_config["sample_rate"],
+                #     window="hamming_window",
+                #     mel_fmin=0.0,
+                #     mel_fmax=None,
+                #     use_htk=True,
+                #     do_amp_to_db=False,
+                #     n_mels=audio_config["num_mels"],
+                #     power=2.0,
+                #     use_mel=True,
+                #     mel_norm=None,
+                # )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
+        else:
+            self.torch_spec = None
+
         self._init_layers()

     def _init_layers(self):
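A quick shape check of the on-the-fly feature path; `PreEmphasis` is coqui's first-order high-pass (pre-emphasis) filter, and the rest is stock torchaudio. Audio settings are the same illustrative assumptions as above:

    import torch
    import torchaudio

    wav = torch.randn(4, 16000)  # assumed: batch of 4 one-second 16 kHz waveforms
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=512,
        win_length=400,
        hop_length=160,
        window_fn=torch.hamming_window,
        n_mels=64,
    )(wav)
    print(mel.shape)  # torch.Size([4, 64, 101]) -- (N, num_mels, frames)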
@@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module):
             elif "weight" in name:
                 nn.init.xavier_normal_(param)

-    def forward(self, x):
-        # TODO: implement state passing for lstms
+    def forward(self, x, l2_norm=True):
+        """Forward pass of the model.
+
+        Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform, `use_torch_spec`
+                must be `True` so the spectrogram is computed on the fly.
+            l2_norm (bool): Whether to L2-normalize the outputs.
+
+        Shapes:
+            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+        """
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                if self.use_torch_spec:
+                    x.squeeze_(1)
+                    x = self.torch_spec(x)
+                x = self.instancenorm(x).transpose(1, 2)
         d = self.layers(x)
         if self.use_lstm_with_projection:
-            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
-        else:
             d = d[:, -1]
+        if l2_norm:
+            d = torch.nn.functional.normalize(d, p=2, dim=1)
         return d

     @torch.no_grad()
-    def inference(self, x):
-        d = self.layers.forward(x)
-        if self.use_lstm_with_projection:
-            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
-        else:
-            d = torch.nn.functional.normalize(d, p=2, dim=1)
+    def inference(self, x, l2_norm=True):
+        d = self.forward(x, l2_norm=l2_norm)
         return d

     def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
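Hedged usage sketch: with `l2_norm=True` (the default) the returned embeddings are unit length, which is what cosine-similarity scoring downstream expects. `model` is the encoder from the construction sketch above, so the input shape assumes `use_torch_spec=True`:

    wav = torch.randn(2, 1, 32000)  # assumed: two 2-second clips at 16 kHz
    emb = model.inference(wav)      # (2, 256): one embedding per clip
    print(emb.norm(p=2, dim=1))     # ~tensor([1., 1.]) when l2_norm=True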
TTS/speaker_encoder/models/resnet.py
@@ -190,8 +190,19 @@ class ResNetSpeakerEncoder(nn.Module):
         return out

     def forward(self, x, l2_norm=False):
+        """Forward pass of the model.
+
+        Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform, `use_torch_spec`
+                must be `True` so the spectrogram is computed on the fly.
+            l2_norm (bool): Whether to L2-normalize the outputs.
+
+        Shapes:
+            - x: :math:`(N, 1, T_{in})`
+        """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 x.squeeze_(1)
+                # if `use_torch_spec` is set, compute the spectrogram here; otherwise `x` is the mel spec computed by the audio processor (AP)
                 if self.use_torch_spec:
                     x = self.torch_spec(x)
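Design note: both encoders wrap feature extraction in `torch.no_grad()` plus `torch.cuda.amp.autocast(enabled=False)`, so the spectrogram stays out of the autograd graph and is always computed in fp32, even under mixed-precision training. The pattern in isolation (an illustrative sketch, not code from this commit):

    import torch

    def extract_features(torch_spec, wav):
        # deterministic preprocessing: no gradients needed, and forcing
        # fp32 keeps AMP from degrading the STFT numerics
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                return torch_spec(wav)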
@@ -230,7 +241,11 @@ class ResNetSpeakerEncoder(nn.Module):
         return x

     @torch.no_grad()
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
+    def inference(self, x, l2_norm=False):
+        return self.forward(x, l2_norm)
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
         """
         Generate embeddings for a batch of utterances
         x: 1xTxD
@@ -254,7 +269,7 @@ class ResNetSpeakerEncoder(nn.Module):
             frames_batch.append(frames)

         frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.forward(frames_batch, l2_norm=True)
+        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

         if return_mean:
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)
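Hedged usage sketch of the updated API: `compute_embedding` slices `num_eval` windows of `num_frames` frames from a single utterance, embeds each window, and with `return_mean=True` averages them into one speaker embedding; `l2_norm` is now forwarded instead of being hard-coded to `True`. Assuming `model` is a constructed `ResNetSpeakerEncoder` and D is the assumed 64 mel bands (input layout follows the `x: 1xTxD` note in the docstring):

    x = torch.randn(1, 1000, 64)  # one utterance: (1, T, num_mels), values assumed
    emb = model.compute_embedding(x, num_frames=250, num_eval=10, return_mean=True)
    print(emb.shape)              # (1, proj_dim): the averaged speaker embedding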
TTS/speaker_encoder/utils/generic_utils.py
@@ -177,6 +177,8 @@ def setup_speaker_encoder_model(config: "Coqpit"):
             config.model_params["proj_dim"],
             config.model_params["lstm_dim"],
             config.model_params["num_lstm_layers"],
+            use_torch_spec=config.model_params.get("use_torch_spec", False),
+            audio_config=config.audio,
         )
     elif config.model_params["model_name"].lower() == "resnet":
         model = ResNetSpeakerEncoder(
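For completeness, a sketch of the config side that feeds this factory; the keys mirror the `model_params` reads above (`input_dim` is an assumption, as it is not shown in this hunk), and every value is illustrative:

    model_params = {
        "model_name": "lstm",
        "input_dim": 64,
        "proj_dim": 256,
        "lstm_dim": 768,
        "num_lstm_layers": 3,
        "use_torch_spec": True,  # new flag; the .get() above defaults it to False when absent
    }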