mirror of https://github.com/coqui-ai/TTS.git
Remove torchaudio requeriment
This commit is contained in:
parent
2e516869a1
commit
d39200e69b
|
@ -1,7 +1,10 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
from torch import nn
|
||||||
import torch.nn as nn
|
|
||||||
|
# import torchaudio
|
||||||
|
|
||||||
|
from TTS.utils.audio import TorchSTFT
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -110,14 +113,29 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
if self.use_torch_spec:
|
if self.use_torch_spec:
|
||||||
self.torch_spec = torch.nn.Sequential(
|
self.torch_spec = torch.nn.Sequential(
|
||||||
PreEmphasis(audio_config["preemphasis"]),
|
PreEmphasis(audio_config["preemphasis"]),
|
||||||
torchaudio.transforms.MelSpectrogram(
|
TorchSTFT(
|
||||||
|
n_fft=audio_config["fft_size"],
|
||||||
|
hop_length=audio_config["hop_length"],
|
||||||
|
win_length=audio_config["win_length"],
|
||||||
|
sample_rate=audio_config["sample_rate"],
|
||||||
|
window="hamming_window",
|
||||||
|
mel_fmin=0.0,
|
||||||
|
mel_fmax=None,
|
||||||
|
use_htk=True,
|
||||||
|
do_amp_to_db=False,
|
||||||
|
n_mels=audio_config["num_mels"],
|
||||||
|
power=2.0,
|
||||||
|
use_mel=True,
|
||||||
|
mel_norm=None
|
||||||
|
),
|
||||||
|
'''torchaudio.transforms.MelSpectrogram(
|
||||||
sample_rate=audio_config["sample_rate"],
|
sample_rate=audio_config["sample_rate"],
|
||||||
n_fft=audio_config["fft_size"],
|
n_fft=audio_config["fft_size"],
|
||||||
win_length=audio_config["win_length"],
|
win_length=audio_config["win_length"],
|
||||||
hop_length=audio_config["hop_length"],
|
hop_length=audio_config["hop_length"],
|
||||||
window_fn=torch.hamming_window,
|
window_fn=torch.hamming_window,
|
||||||
n_mels=audio_config["num_mels"],
|
n_mels=audio_config["num_mels"],
|
||||||
),
|
),'''
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.torch_spec = None
|
self.torch_spec = None
|
||||||
|
|
|
@ -4,7 +4,7 @@ from itertools import chain
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
# import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
@ -395,7 +395,7 @@ class Vits(BaseTTS):
|
||||||
if config.use_speaker_encoder_as_loss:
|
if config.use_speaker_encoder_as_loss:
|
||||||
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
||||||
)
|
)
|
||||||
self.speaker_manager.init_speaker_encoder(
|
self.speaker_manager.init_speaker_encoder(
|
||||||
config.speaker_encoder_model_path, config.speaker_encoder_config_path
|
config.speaker_encoder_model_path, config.speaker_encoder_config_path
|
||||||
|
@ -410,14 +410,17 @@ class Vits(BaseTTS):
|
||||||
hasattr(self.speaker_encoder, "audio_config")
|
hasattr(self.speaker_encoder, "audio_config")
|
||||||
and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
|
and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
|
||||||
):
|
):
|
||||||
self.audio_transform = torchaudio.transforms.Resample(
|
raise RuntimeError(
|
||||||
|
" [!] To use the speaker consistency loss (SCL) you need to have the TTS model sampling rate ({}) equal to the speaker encoder sampling rate ({}) !".format(self.audio_config["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
|
||||||
|
)
|
||||||
|
'''self.audio_transform = torchaudio.transforms.Resample(
|
||||||
orig_freq=self.audio_config["sample_rate"],
|
orig_freq=self.audio_config["sample_rate"],
|
||||||
new_freq=self.speaker_encoder.audio_config["sample_rate"],
|
new_freq=self.speaker_encoder.audio_config["sample_rate"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.audio_transform = None
|
self.audio_transform = None'''
|
||||||
else:
|
else:
|
||||||
self.audio_transform = None
|
# self.audio_transform = None
|
||||||
self.speaker_encoder = None
|
self.speaker_encoder = None
|
||||||
|
|
||||||
def _init_speaker_embedding(self, config):
|
def _init_speaker_embedding(self, config):
|
||||||
|
@ -655,8 +658,8 @@ class Vits(BaseTTS):
|
||||||
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
||||||
|
|
||||||
# resample audio to speaker encoder sample_rate
|
# resample audio to speaker encoder sample_rate
|
||||||
if self.audio_transform is not None:
|
'''if self.audio_transform is not None:
|
||||||
wavs_batch = self.audio_transform(wavs_batch)
|
wavs_batch = self.audio_transform(wavs_batch)'''
|
||||||
|
|
||||||
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
use_mel=False,
|
use_mel=False,
|
||||||
do_amp_to_db=False,
|
do_amp_to_db=False,
|
||||||
spec_gain=1.0,
|
spec_gain=1.0,
|
||||||
|
power=None,
|
||||||
|
use_htk=False,
|
||||||
|
mel_norm="slaney"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_fft = n_fft
|
self.n_fft = n_fft
|
||||||
|
@ -45,6 +48,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
self.use_mel = use_mel
|
self.use_mel = use_mel
|
||||||
self.do_amp_to_db = do_amp_to_db
|
self.do_amp_to_db = do_amp_to_db
|
||||||
self.spec_gain = spec_gain
|
self.spec_gain = spec_gain
|
||||||
|
self.power = power
|
||||||
|
self.use_htk = use_htk
|
||||||
|
self.mel_norm = mel_norm
|
||||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||||
self.mel_basis = None
|
self.mel_basis = None
|
||||||
if use_mel:
|
if use_mel:
|
||||||
|
@ -83,6 +89,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
M = o[:, :, :, 0]
|
M = o[:, :, :, 0]
|
||||||
P = o[:, :, :, 1]
|
P = o[:, :, :, 1]
|
||||||
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
||||||
|
|
||||||
|
if self.power is not None:
|
||||||
|
S = S ** self.power
|
||||||
|
|
||||||
if self.use_mel:
|
if self.use_mel:
|
||||||
S = torch.matmul(self.mel_basis.to(x), S)
|
S = torch.matmul(self.mel_basis.to(x), S)
|
||||||
if self.do_amp_to_db:
|
if self.do_amp_to_db:
|
||||||
|
@ -91,7 +101,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
|
|
||||||
def _build_mel_basis(self):
|
def _build_mel_basis(self):
|
||||||
mel_basis = librosa.filters.mel(
|
mel_basis = librosa.filters.mel(
|
||||||
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
|
||||||
)
|
)
|
||||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||||
|
|
||||||
|
|
|
@ -26,5 +26,3 @@ unidic-lite==1.0.8
|
||||||
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
|
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
|
||||||
fsspec>=2021.04.0
|
fsspec>=2021.04.0
|
||||||
pyworld
|
pyworld
|
||||||
webrtcvad
|
|
||||||
torchaudio>=0.7
|
|
||||||
|
|
Loading…
Reference in New Issue