Add H/ASP original checkpoint support

This commit is contained in:
Edresson 2021-09-01 09:23:45 -03:00 committed by Eren Gölge
parent 0bdfd3cb50
commit 9b011b1cb3
3 changed files with 51 additions and 8 deletions

View File

@ -1,9 +1,23 @@
import numpy as np import numpy as np
import torch import torch
from torch import nn import torchaudio
import torch.nn as nn
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
class PreEmphasis(torch.nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer(
'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
)
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class SELayer(nn.Module): class SELayer(nn.Module):
def __init__(self, channel, reduction=8): def __init__(self, channel, reduction=8):
@ -70,12 +84,17 @@ class ResNetSpeakerEncoder(nn.Module):
num_filters=[32, 64, 128, 256], num_filters=[32, 64, 128, 256],
encoder_type="ASP", encoder_type="ASP",
log_input=False, log_input=False,
use_torch_spec=False,
audio_config=None,
): ):
super(ResNetSpeakerEncoder, self).__init__() super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type self.encoder_type = encoder_type
self.input_dim = input_dim self.input_dim = input_dim
self.log_input = log_input self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0]) self.bn1 = nn.BatchNorm2d(num_filters[0])
@ -88,6 +107,14 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim) self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8) outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential( self.attention = nn.Sequential(
@ -140,9 +167,13 @@ class ResNetSpeakerEncoder(nn.Module):
return out return out
def forward(self, x, l2_norm=False): def forward(self, x, l2_norm=False):
x = x.transpose(1, 2)
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x = self.torch_spec(x)
else:
x = x.transpose(1, 2)
if self.log_input: if self.log_input:
x = (x + 1e-6).log() x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1) x = self.instancenorm(x).unsqueeze(1)
@ -180,6 +211,10 @@ class ResNetSpeakerEncoder(nn.Module):
Generate embeddings for a batch of utterances Generate embeddings for a batch of utterances
x: 1xTxD x: 1xTxD
""" """
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config['hop_length']
max_len = x.shape[1] max_len = x.shape[1]
if max_len < num_frames: if max_len < num_frames:

View File

@ -179,7 +179,11 @@ def setup_model(c):
c.model_params["num_lstm_layers"], c.model_params["num_lstm_layers"],
) )
elif c.model_params["model_name"].lower() == "resnet": elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"]) model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
log_input=c.model_params.get("log_input", False),
use_torch_spec=c.model_params.get("use_torch_spec", False),
audio_config=c.audio
)
return model return model

View File

@ -288,12 +288,16 @@ class SpeakerManager:
def _compute(wav_file: str): def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform) if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
spec = torch.from_numpy(spec.T) m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input.T)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda: if self.use_cuda:
spec = spec.cuda() m_input = m_input.cuda()
spec = spec.unsqueeze(0) m_input = m_input.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(spec) d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector return d_vector
if isinstance(wav_file, list): if isinstance(wav_file, list):