From 9b011b1cb3849150d4cdb2ee06d022771ad7aee6 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Wed, 1 Sep 2021 09:23:45 -0300
Subject: [PATCH] Add H/ASP original checkpoint support

---
 TTS/speaker_encoder/models/resnet.py       | 39 ++++++++++++++++++++--
 TTS/speaker_encoder/utils/generic_utils.py |  6 +++-
 TTS/tts/utils/speakers.py                  | 14 +++++---
 3 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index fcc850d7..beeb5ae1 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -1,9 +1,23 @@
 import numpy as np
 import torch
-from torch import nn
+import torchaudio
+import torch.nn as nn
 
 from TTS.utils.io import load_fsspec
 
+class PreEmphasis(torch.nn.Module):
+    def __init__(self, coefficient=0.97):
+        super().__init__()
+        self.coefficient = coefficient
+        self.register_buffer(
+            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
+        )
+
+    def forward(self, x):
+        assert len(x.size()) == 2
+
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
+        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
 
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
@@ -70,12 +84,17 @@ class ResNetSpeakerEncoder(nn.Module):
         num_filters=[32, 64, 128, 256],
         encoder_type="ASP",
         log_input=False,
+        use_torch_spec=False,
+        audio_config=None,
     ):
         super(ResNetSpeakerEncoder, self).__init__()
 
         self.encoder_type = encoder_type
         self.input_dim = input_dim
         self.log_input = log_input
+        self.use_torch_spec = use_torch_spec
+        self.audio_config = audio_config
+
         self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
         self.relu = nn.ReLU(inplace=True)
         self.bn1 = nn.BatchNorm2d(num_filters[0])
@@ -88,6 +107,14 @@ class ResNetSpeakerEncoder(nn.Module):
 
         self.instancenorm = nn.InstanceNorm1d(input_dim)
 
+        if self.use_torch_spec:
+            self.torch_spec = torch.nn.Sequential(
+                PreEmphasis(audio_config["preemphasis"]),
+                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
+            )
+        else:
+            self.torch_spec = None
+
         outmap_size = int(self.input_dim / 8)
 
         self.attention = nn.Sequential(
@@ -140,9 +167,13 @@ class ResNetSpeakerEncoder(nn.Module):
         return out
 
     def forward(self, x, l2_norm=False):
-        x = x.transpose(1, 2)
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
+                if self.use_torch_spec:
+                    x = self.torch_spec(x)
+                else:
+                    x = x.transpose(1, 2)
+
                 if self.log_input:
                     x = (x + 1e-6).log()
                 x = self.instancenorm(x).unsqueeze(1)
@@ -180,6 +211,10 @@ class ResNetSpeakerEncoder(nn.Module):
         Generate embeddings for a batch of utterances
         x: 1xTxD
         """
+        # map to the waveform size
+        if self.use_torch_spec:
+            num_frames = num_frames * self.audio_config['hop_length']
+
         max_len = x.shape[1]
 
         if max_len < num_frames:
diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py
index 1981fbe9..3714e3c4 100644
--- a/TTS/speaker_encoder/utils/generic_utils.py
+++ b/TTS/speaker_encoder/utils/generic_utils.py
@@ -179,7 +179,11 @@ def setup_model(c):
             c.model_params["num_lstm_layers"],
         )
     elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"])
+        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"],
+                                     proj_dim=c.model_params["proj_dim"],
+                                     log_input=c.model_params.get("log_input", False),
+                                     use_torch_spec=c.model_params.get("use_torch_spec", False),
+                                     audio_config=c.audio)
     return model
 
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 1497ca74..282875af 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -288,12 +288,16 @@ class SpeakerManager:
 
         def _compute(wav_file: str):
             waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
-            spec = self.speaker_encoder_ap.melspectrogram(waveform)
-            spec = torch.from_numpy(spec.T)
+            if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
+                m_input = self.speaker_encoder_ap.melspectrogram(waveform)
+                m_input = torch.from_numpy(m_input.T)
+            else:
+                m_input = torch.from_numpy(waveform)
+
             if self.use_cuda:
-                spec = spec.cuda()
-            spec = spec.unsqueeze(0)
-            d_vector = self.speaker_encoder.compute_embedding(spec)
+                m_input = m_input.cuda()
+            m_input = m_input.unsqueeze(0)
+            d_vector = self.speaker_encoder.compute_embedding(m_input)
             return d_vector
 
         if isinstance(wav_file, list):
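
Usage note (reviewer sketch, not part of the patch): a minimal example of how the new use_torch_spec waveform path could be exercised once the patch is applied. The audio values, input_dim, and proj_dim below are illustrative assumptions chosen to be self-consistent; they are not taken from any shipped config.

    import torch

    from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder

    # assumed audio settings; keys match those read by the patch
    audio_config = {
        "preemphasis": 0.97,
        "sample_rate": 16000,
        "fft_size": 512,
        "win_length": 400,
        "hop_length": 160,
        "num_mels": 64,
    }

    model = ResNetSpeakerEncoder(
        input_dim=audio_config["num_mels"],  # mel bins must match input_dim
        proj_dim=512,
        use_torch_spec=True,
        audio_config=audio_config,
    )
    model.eval()

    # one 2-second waveform; with use_torch_spec the model consumes raw audio
    # and computes pre-emphasis + mel-spectrogram internally
    waveform = torch.randn(1, 2 * audio_config["sample_rate"])
    with torch.no_grad():
        embedding = model(waveform, l2_norm=True)
    print(embedding.shape)  # expected: torch.Size([1, 512]) with proj_dim=512

Without use_torch_spec the model keeps its old contract and expects a precomputed mel-spectrogram of shape (batch, frames, num_mels), which is why SpeakerManager._compute above only skips the melspectrogram step when the config enables the torch spectrogram.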