mirror of https://github.com/coqui-ai/TTS.git
Add H/ASP original checkpoint support
parent 0bdfd3cb50
commit 9b011b1cb3
@@ -1,9 +1,23 @@
import numpy as np
import torch
from torch import nn
import torchaudio
import torch.nn as nn

from TTS.utils.io import load_fsspec


class PreEmphasis(torch.nn.Module):
    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        self.register_buffer(
            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        assert len(x.size()) == 2

        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
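The PreEmphasis module above implements the first-order high-pass filter y[t] = x[t] - coefficient * x[t-1] as a fixed conv1d over a reflect-padded waveform. A minimal sketch of how it behaves, assuming the class from the diff is in scope; the shapes and the NumPy cross-check are illustrative, not part of the commit:

```python
import numpy as np
import torch

pre_emphasis = PreEmphasis(coefficient=0.97)  # class defined in the diff above

wav = torch.randn(2, 16000)       # [batch, samples]
out = pre_emphasis(wav)           # same shape as the input: [2, 16000]

# NumPy equivalent: y[t] = x[t] - 0.97 * x[t-1]; the first sample uses the
# reflect-padded neighbour (x[-1] taken as x[1]) to match the module's padding
x = wav[0].numpy()
ref = x - 0.97 * np.concatenate(([x[1]], x[:-1]))
print(out.shape, np.allclose(out[0].numpy(), ref, atol=1e-5))
```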
@@ -70,12 +84,17 @@ class ResNetSpeakerEncoder(nn.Module):
        num_filters=[32, 64, 128, 256],
        encoder_type="ASP",
        log_input=False,
        use_torch_spec=False,
        audio_config=None,
    ):
        super(ResNetSpeakerEncoder, self).__init__()

        self.encoder_type = encoder_type
        self.input_dim = input_dim
        self.log_input = log_input
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config

        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
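With the new keyword arguments, the encoder can be built to consume raw audio directly. A hedged instantiation sketch: the import path and the audio_config values are assumptions, only the keys are taken from how the diff indexes audio_config:

```python
# assumed module path for illustration; the diff itself does not show file names
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder

audio_config = {          # keys read by the diff; values here are placeholders
    "preemphasis": 0.97,
    "sample_rate": 16000,
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "num_mels": 64,
}

model = ResNetSpeakerEncoder(
    input_dim=64,             # should match num_mels when the model builds its own spec
    proj_dim=512,
    log_input=True,
    use_torch_spec=True,      # feed waveforms; the model computes the mel spectrogram
    audio_config=audio_config,
)
```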
@@ -88,6 +107,14 @@ class ResNetSpeakerEncoder(nn.Module):

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        if self.use_torch_spec:
            self.torch_spec = torch.nn.Sequential(
                PreEmphasis(audio_config["preemphasis"]),
                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
            )
        else:
            self.torch_spec = None

        outmap_size = int(self.input_dim / 8)

        self.attention = nn.Sequential(
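When use_torch_spec is set, self.torch_spec turns a batch of waveforms into a mel spectrogram before the ResNet sees it. A standalone sketch of the same torchaudio pipeline and the resulting shapes, assuming PreEmphasis from this diff is in scope and using placeholder parameter values:

```python
import torch
import torchaudio

torch_spec = torch.nn.Sequential(
    PreEmphasis(0.97),                         # from the diff above
    torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_fft=512,
        win_length=400,
        hop_length=160,
        window_fn=torch.hamming_window,
        n_mels=64,
    ),
)

wav = torch.randn(4, 32000)   # [batch, samples], ~2 s at 16 kHz
mel = torch_spec(wav)         # [batch, n_mels, frames] -> torch.Size([4, 64, 201])
print(mel.shape)
```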
@@ -140,9 +167,13 @@ class ResNetSpeakerEncoder(nn.Module):
        return out

    def forward(self, x, l2_norm=False):
        x = x.transpose(1, 2)
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    x = self.torch_spec(x)
                else:
                    x = x.transpose(1, 2)

                if self.log_input:
                    x = (x + 1e-6).log()
                x = self.instancenorm(x).unsqueeze(1)
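Inside forward, the spectrogram is produced under no_grad and with autocast disabled, then log-compressed and instance-normalized before the 2D conv stack. A minimal sketch of just that normalization step, with illustrative shapes:

```python
import torch
from torch import nn

num_mels, frames = 64, 201
mel = torch.rand(4, num_mels, frames) + 1e-3   # stand-in for torch_spec output

instancenorm = nn.InstanceNorm1d(num_mels)

feats = (mel + 1e-6).log()                      # log-compress, as when log_input is True
feats = instancenorm(feats).unsqueeze(1)        # per-utterance norm + channel dim
print(feats.shape)                              # torch.Size([4, 1, 64, 201])
```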
@@ -180,6 +211,10 @@ class ResNetSpeakerEncoder(nn.Module):
        Generate embeddings for a batch of utterances
        x: 1xTxD
        """
        # map to the waveform size
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config['hop_length']

        max_len = x.shape[1]

        if max_len < num_frames:
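compute_embedding slides a window of num_frames over the input, so when the model ingests raw audio the frame count has to be converted to a sample count. The mapping is simply frames * hop_length; for example:

```python
# illustrative numbers: 250 spectrogram frames at hop_length=160 (16 kHz audio)
hop_length = 160
num_frames = 250
num_samples = num_frames * hop_length   # 40000 samples, i.e. 2.5 s of waveform
print(num_samples)
```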
@@ -179,7 +179,11 @@ def setup_model(c):
            c.model_params["num_lstm_layers"],
        )
    elif c.model_params["model_name"].lower() == "resnet":
        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
                                     log_input=c.model_params.get("log_input", False),
                                     use_torch_spec=c.model_params.get("use_torch_spec", False),
                                     audio_config=c.audio
                                     )
    return model
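setup_model now forwards the new options from the config to the ResNet encoder. A hedged sketch of the model_params section such a config might carry; only the keys read in this diff are shown, and the values are placeholders:

```python
# model_params as consumed by setup_model in this diff (values illustrative);
# c.audio is passed through unchanged as the encoder's audio_config
model_params = {
    "model_name": "resnet",
    "input_dim": 64,
    "proj_dim": 512,
    "log_input": True,
    "use_torch_spec": True,
}
```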
@@ -288,12 +288,16 @@ class SpeakerManager:

        def _compute(wav_file: str):
            waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
            spec = self.speaker_encoder_ap.melspectrogram(waveform)
            spec = torch.from_numpy(spec.T)
            if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.speaker_encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input.T)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                spec = spec.cuda()
            spec = spec.unsqueeze(0)
            d_vector = self.speaker_encoder.compute_embedding(spec)
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)
            d_vector = self.speaker_encoder.compute_embedding(m_input)
            return d_vector

        if isinstance(wav_file, list):
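On the SpeakerManager side, the closure now feeds the raw waveform to the encoder whenever use_torch_spec is enabled, instead of the AudioProcessor mel spectrogram. A rough usage sketch; the import path, constructor arguments, method name and file paths below are assumptions for illustration and are not shown in this diff:

```python
from TTS.tts.utils.speakers import SpeakerManager  # assumed import path

manager = SpeakerManager(
    encoder_model_path="path/to/se_checkpoint.pth.tar",  # placeholder paths
    encoder_config_path="path/to/se_config.json",
    use_cuda=True,
)
# internally this reaches the _compute closure above: load the wav, pick raw
# waveform vs. mel spectrogram depending on use_torch_spec, embed, return d-vector
d_vector = manager.compute_d_vector_from_clip("path/to/utterance.wav")
```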