mirror of https://github.com/coqui-ai/TTS.git

Update Speaker Encoder models

parent 2033e17c44
commit 638091f41d
@@ -250,4 +250,4 @@ class SpeakerEncoderDataset(Dataset):
         feats = torch.stack(feats)
         labels = torch.stack(labels)
 
-        return feats.transpose(1, 2), labels
+        return feats, labels
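The collate change above stops transposing the stacked features, so the loader now yields batches in (batch, frames, mels) order rather than (batch, mels, frames). A minimal sketch of the shape difference, assuming each dataset item is a (num_frames, num_mels) feature matrix (the exact item layout is not shown in this diff):

import torch

# Hypothetical per-item features: (num_frames=250, num_mels=80).
feats = [torch.randn(250, 80) for _ in range(10)]

batch = torch.stack(feats)          # (10, 250, 80) -- what the collate returns now
old_batch = batch.transpose(1, 2)   # (10, 80, 250) -- what it returned before

print(batch.shape, old_batch.shape)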
@@ -1,7 +1,9 @@
 import numpy as np
 import torch
+import torchaudio
 from torch import nn
 
+from TTS.speaker_encoder.models.resnet import PreEmphasis
 from TTS.utils.io import load_fsspec
 
 
@@ -33,9 +35,21 @@ class LSTMWithoutProjection(nn.Module):
 
 
 class LSTMSpeakerEncoder(nn.Module):
-    def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
+    def __init__(
+        self,
+        input_dim,
+        proj_dim=256,
+        lstm_dim=768,
+        num_lstm_layers=3,
+        use_lstm_with_projection=True,
+        use_torch_spec=False,
+        audio_config=None,
+    ):
         super().__init__()
         self.use_lstm_with_projection = use_lstm_with_projection
+        self.use_torch_spec = use_torch_spec
+        self.audio_config = audio_config
 
         layers = []
         # choise LSTM layer
         if use_lstm_with_projection:
@@ -46,6 +60,38 @@ class LSTMSpeakerEncoder(nn.Module):
         else:
             self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
 
+        self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+        if self.use_torch_spec:
+            self.torch_spec = torch.nn.Sequential(
+                PreEmphasis(audio_config["preemphasis"]),
+                # TorchSTFT(
+                #     n_fft=audio_config["fft_size"],
+                #     hop_length=audio_config["hop_length"],
+                #     win_length=audio_config["win_length"],
+                #     sample_rate=audio_config["sample_rate"],
+                #     window="hamming_window",
+                #     mel_fmin=0.0,
+                #     mel_fmax=None,
+                #     use_htk=True,
+                #     do_amp_to_db=False,
+                #     n_mels=audio_config["num_mels"],
+                #     power=2.0,
+                #     use_mel=True,
+                #     mel_norm=None,
+                # )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
+        else:
+            self.torch_spec = None
+
         self._init_layers()
 
     def _init_layers(self):
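With the widened constructor, the LSTM encoder can now build its own mel-spectrogram front end. A minimal construction sketch, assuming the class lives in TTS.speaker_encoder.models.lstm (the diff imports PreEmphasis from the sibling resnet module) and assuming plausible values for the audio_config keys the diff reads:

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder  # assumed module path

# Assumed values -- only the key names (preemphasis, fft_size, hop_length,
# win_length, sample_rate, num_mels) appear in the diff, not their defaults.
audio_config = {
    "preemphasis": 0.97,
    "fft_size": 512,
    "hop_length": 160,
    "win_length": 400,
    "sample_rate": 16000,
    "num_mels": 64,
}

model = LSTMSpeakerEncoder(
    input_dim=64,             # must equal num_mels when use_torch_spec=True,
                              # since InstanceNorm1d(input_dim) runs on the mel channels
    proj_dim=256,
    lstm_dim=768,
    num_lstm_layers=3,
    use_torch_spec=True,      # compute PreEmphasis + MelSpectrogram inside the model
    audio_config=audio_config,
)

Leaving use_torch_spec at its default of False keeps the old behavior: torch_spec stays None and the model expects precomputed spectrogram frames.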
@@ -55,22 +101,33 @@ class LSTMSpeakerEncoder(nn.Module):
             elif "weight" in name:
                 nn.init.xavier_normal_(param)
 
-    def forward(self, x):
-        # TODO: implement state passing for lstms
+    def forward(self, x, l2_norm=True):
+        """Forward pass of the model.
+
+        Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
+                to compute the spectrogram on-the-fly.
+            l2_norm (bool): Whether to L2-normalize the outputs.
+
+        Shapes:
+            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+        """
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                if self.use_torch_spec:
+                    x.squeeze_(1)
+                    x = self.torch_spec(x)
+                x = self.instancenorm(x).transpose(1, 2)
         d = self.layers(x)
         if self.use_lstm_with_projection:
-            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
-        else:
+            d = d[:, -1]
+        if l2_norm:
             d = torch.nn.functional.normalize(d, p=2, dim=1)
         return d
 
     @torch.no_grad()
-    def inference(self, x):
-        d = self.layers.forward(x)
-        if self.use_lstm_with_projection:
-            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
-        else:
-            d = torch.nn.functional.normalize(d, p=2, dim=1)
+    def inference(self, x, l2_norm=True):
+        d = self.forward(x, l2_norm=l2_norm)
         return d
 
     def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
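The normalization logic now lives in forward behind an l2_norm flag, and inference simply delegates to it under no_grad. A self-contained sketch with a default (spectrogram-input) encoder; the shapes are assumptions for illustration:

import torch

from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder  # assumed module path

model = LSTMSpeakerEncoder(input_dim=80)        # defaults: proj_dim=256, use_torch_spec=False

frames = torch.randn(4, 80, 250)                # (N, D_spec, T): 4 utterances, 80 mel bins, 250 frames
emb = model.inference(frames)                   # L2-normalized by default, shape (4, 256)
raw = model.inference(frames, l2_norm=False)    # unnormalized embeddings

# The two calls differ only by the final normalization step.
torch.testing.assert_close(emb, torch.nn.functional.normalize(raw, p=2, dim=1))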
@@ -190,8 +190,19 @@ class ResNetSpeakerEncoder(nn.Module):
         return out
 
     def forward(self, x, l2_norm=False):
+        """Forward pass of the model.
+
+        Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
+                to compute the spectrogram on-the-fly.
+            l2_norm (bool): Whether to L2-normalize the outputs.
+
+        Shapes:
+            - x: :math:`(N, 1, T_{in})`
+        """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
+                x.squeeze_(1)
                 # if you torch spec compute it otherwise use the mel spec computed by the AP
                 if self.use_torch_spec:
                     x = self.torch_spec(x)
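The ResNet forward now squeezes the channel dimension up front, so waveform inputs follow the documented (N, 1, T_in) shape before the optional on-the-fly spectrogram. A standalone sketch of the shape flow through a torchaudio MelSpectrogram configured like the one in this commit (parameter values assumed):

import torch
import torchaudio

spec = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=512, win_length=400, hop_length=160,
    window_fn=torch.hamming_window, n_mels=64,
)

wav = torch.randn(4, 1, 16000).squeeze(1)  # mirrors x.squeeze_(1): (N, 1, T) -> (N, T)
mel = spec(wav)                            # (N, n_mels, 1 + T // hop_length) with default centering
print(mel.shape)                           # torch.Size([4, 64, 101])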
@@ -230,7 +241,11 @@ class ResNetSpeakerEncoder(nn.Module):
         return x
 
     @torch.no_grad()
-    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
+    def inference(self, x, l2_norm=False):
+        return self.forward(x, l2_norm)
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
         """
         Generate embeddings for a batch of utterances
         x: 1xTxD
@@ -254,7 +269,7 @@ class ResNetSpeakerEncoder(nn.Module):
             frames_batch.append(frames)
 
         frames_batch = torch.cat(frames_batch, dim=0)
-        embeddings = self.forward(frames_batch, l2_norm=True)
+        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
 
         if return_mean:
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)
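compute_embedding keeps its sliding-window evaluation but now routes the new l2_norm flag through inference. A standalone sketch of that windowing pattern; the offset policy here (evenly spaced window starts) is an assumption, since the slicing code falls outside this hunk:

import torch

def windowed_embedding(model, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
    # x: (1, T, D) utterance; slice num_eval windows of num_frames frames each.
    offsets = torch.linspace(0, x.shape[1] - num_frames, num_eval).long()
    frames_batch = torch.cat([x[:, o : o + num_frames] for o in offsets], dim=0)
    embeddings = model.inference(frames_batch, l2_norm=l2_norm)   # (num_eval, proj_dim)
    if return_mean:
        embeddings = torch.mean(embeddings, dim=0, keepdim=True)  # (1, proj_dim)
    return embeddings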
@@ -177,6 +177,8 @@ def setup_speaker_encoder_model(config: "Coqpit"):
            config.model_params["proj_dim"],
            config.model_params["lstm_dim"],
            config.model_params["num_lstm_layers"],
+           use_torch_spec=config.model_params.get("use_torch_spec", False),
+           audio_config=config.audio,
        )
    elif config.model_params["model_name"].lower() == "resnet":
        model = ResNetSpeakerEncoder(
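setup_speaker_encoder_model now threads the spectrogram options from the config into the LSTM encoder. A sketch of the config shape it reads; the field values and the use of SimpleNamespace are stand-ins for the project's Coqpit config object:

from types import SimpleNamespace

config = SimpleNamespace(
    model_params={
        "model_name": "lstm",
        "input_dim": 64,
        "proj_dim": 256,
        "lstm_dim": 768,
        "num_lstm_layers": 3,
        "use_torch_spec": True,   # read via .get(..., False), so older configs
                                  # without this key keep working unchanged
    },
    audio={
        "preemphasis": 0.97,
        "fft_size": 512,
        "hop_length": 160,
        "win_length": 400,
        "sample_rate": 16000,
        "num_mels": 64,
    },
)

# model = setup_speaker_encoder_model(config)  # selects LSTMSpeakerEncoder for "lstm"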