From d39200e69b29b107302dcfc124a17faf0c0f6cdd Mon Sep 17 00:00:00 2001
From: Edresson
Date: Tue, 23 Nov 2021 11:24:36 -0300
Subject: [PATCH] Remove torchaudio requirement

---
 TTS/speaker_encoder/models/resnet.py | 26 ++++++++++++++++++++++----
 TTS/tts/models/vits.py               | 21 ++++++++++++---------
 TTS/utils/audio.py                   | 12 +++++++++++-
 requirements.txt                     |  2 --
 4 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index 47b6f23f..8f0a8809 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -1,7 +1,10 @@
 import numpy as np
 import torch
-import torchaudio
-import torch.nn as nn
+from torch import nn
+
+# import torchaudio
+
+from TTS.utils.audio import TorchSTFT
 
 from TTS.utils.io import load_fsspec
 
@@ -110,14 +113,29 @@ class ResNetSpeakerEncoder(nn.Module):
         if self.use_torch_spec:
             self.torch_spec = torch.nn.Sequential(
                 PreEmphasis(audio_config["preemphasis"]),
-                torchaudio.transforms.MelSpectrogram(
+                TorchSTFT(
+                    n_fft=audio_config["fft_size"],
+                    hop_length=audio_config["hop_length"],
+                    win_length=audio_config["win_length"],
+                    sample_rate=audio_config["sample_rate"],
+                    window="hamming_window",
+                    mel_fmin=0.0,
+                    mel_fmax=None,
+                    use_htk=True,
+                    do_amp_to_db=False,
+                    n_mels=audio_config["num_mels"],
+                    power=2.0,
+                    use_mel=True,
+                    mel_norm=None
+                ),
+                '''torchaudio.transforms.MelSpectrogram(
                     sample_rate=audio_config["sample_rate"],
                     n_fft=audio_config["fft_size"],
                     win_length=audio_config["win_length"],
                     hop_length=audio_config["hop_length"],
                     window_fn=torch.hamming_window,
                     n_mels=audio_config["num_mels"],
-                ),
+                ),'''
             )
         else:
             self.torch_spec = None
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index ac0f5d69..4eb12b3b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -4,7 +4,7 @@ from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
-import torchaudio
+# import torchaudio
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
@@ -395,7 +395,7 @@ class Vits(BaseTTS):
         if config.use_speaker_encoder_as_loss:
             if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
                 raise RuntimeError(
-                    " [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                 )
             self.speaker_manager.init_speaker_encoder(
                 config.speaker_encoder_model_path, config.speaker_encoder_config_path
@@ -410,14 +410,17 @@
                 hasattr(self.speaker_encoder, "audio_config")
                 and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
             ):
-                self.audio_transform = torchaudio.transforms.Resample(
+                raise RuntimeError(
+                    " [!] To use the speaker consistency loss (SCL) you need to have the TTS model sampling rate ({}) equal to the speaker encoder sampling rate ({}) !".format(self.audio_config["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
+                )
+                '''self.audio_transform = torchaudio.transforms.Resample(
                     orig_freq=self.audio_config["sample_rate"],
                     new_freq=self.speaker_encoder.audio_config["sample_rate"],
-                )
-            else:
-                self.audio_transform = None
+                )
+            else:
+                self.audio_transform = None'''
         else:
-            self.audio_transform = None
+            # self.audio_transform = None
             self.speaker_encoder = None
 
     def _init_speaker_embedding(self, config):
@@ -655,8 +658,8 @@ class Vits(BaseTTS):
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
 
             # resample audio to speaker encoder sample_rate
-            if self.audio_transform is not None:
-                wavs_batch = self.audio_transform(wavs_batch)
+            '''if self.audio_transform is not None:
+                wavs_batch = self.audio_transform(wavs_batch)'''
 
             pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index e64b95e0..d650c288 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -32,6 +32,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
         use_mel=False,
         do_amp_to_db=False,
         spec_gain=1.0,
+        power=None,
+        use_htk=False,
+        mel_norm="slaney"
     ):
         super().__init__()
         self.n_fft = n_fft
@@ -45,6 +48,9 @@
         self.use_mel = use_mel
         self.do_amp_to_db = do_amp_to_db
         self.spec_gain = spec_gain
+        self.power = power
+        self.use_htk = use_htk
+        self.mel_norm = mel_norm
         self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
         self.mel_basis = None
         if use_mel:
@@ -83,6 +89,10 @@
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
         S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+
+        if self.power is not None:
+            S = S ** self.power
+
         if self.use_mel:
             S = torch.matmul(self.mel_basis.to(x), S)
         if self.do_amp_to_db:
@@ -91,7 +101,7 @@
 
     def _build_mel_basis(self):
         mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
diff --git a/requirements.txt b/requirements.txt
index cf4798b2..3ec33ceb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,5 +26,3 @@ unidic-lite==1.0.8
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld
-webrtcvad
-torchaudio>=0.7
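
Note (not part of the patch): the TorchSTFT arguments used in resnet.py above are chosen to reproduce the behaviour of the replaced torchaudio.transforms.MelSpectrogram, whose defaults are a power-2.0 spectrogram, the HTK mel scale, and no mel filterbank normalization. A minimal sketch of that configuration follows; the audio settings are hypothetical example values standing in for audio_config, not taken from a real config.

    import torch

    from TTS.utils.audio import TorchSTFT

    # Example values in place of audio_config: 512-point FFT, 400/160 window/hop,
    # 16 kHz sample rate, 64 mel bands.
    torch_spec = TorchSTFT(
        n_fft=512,
        hop_length=160,
        win_length=400,
        sample_rate=16000,
        window="hamming_window",
        mel_fmin=0.0,
        mel_fmax=None,
        use_htk=True,       # HTK mel scale, matching torchaudio's default
        do_amp_to_db=False,
        n_mels=64,
        power=2.0,          # power spectrogram, matching torchaudio's default
        use_mel=True,
        mel_norm=None,      # no filterbank normalization, matching torchaudio's default
    )

    wav = torch.randn(1, 1, 16000)   # one second of dummy audio at 16 kHz
    mel = torch_spec(wav)            # -> [1, n_mels, n_frames]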