Add H/ASP original checkpoint support

This commit is contained in:
Edresson 2021-09-01 09:23:45 -03:00 committed by Eren Gölge
parent 0bdfd3cb50
commit 9b011b1cb3
3 changed files with 51 additions and 8 deletions

View File

@ -1,9 +1,23 @@
import numpy as np
import torch
from torch import nn
import torchaudio
import torch.nn as nn
from TTS.utils.io import load_fsspec
class PreEmphasis(torch.nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer(
'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
)
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
@ -70,12 +84,17 @@ class ResNetSpeakerEncoder(nn.Module):
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
use_torch_spec=False,
audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
@ -88,6 +107,14 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
@ -140,9 +167,13 @@ class ResNetSpeakerEncoder(nn.Module):
return out
def forward(self, x, l2_norm=False):
x = x.transpose(1, 2)
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x = self.torch_spec(x)
else:
x = x.transpose(1, 2)
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
@ -180,6 +211,10 @@ class ResNetSpeakerEncoder(nn.Module):
Generate embeddings for a batch of utterances
x: 1xTxD
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config['hop_length']
max_len = x.shape[1]
if max_len < num_frames:

View File

@ -179,7 +179,11 @@ def setup_model(c):
c.model_params["num_lstm_layers"],
)
elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
log_input=c.model_params.get("log_input", False),
use_torch_spec=c.model_params.get("use_torch_spec", False),
audio_config=c.audio
)
return model

View File

@ -288,12 +288,16 @@ class SpeakerManager:
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform)
spec = torch.from_numpy(spec.T)
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input.T)
else:
m_input = torch.from_numpy(waveform)
if self.use_cuda:
spec = spec.cuda()
spec = spec.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(spec)
m_input = m_input.cuda()
m_input = m_input.unsqueeze(0)
d_vector = self.speaker_encoder.compute_embedding(m_input)
return d_vector
if isinstance(wav_file, list):