diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 53df94a3..ed655a32 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -11,7 +11,6 @@ from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
 
-
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")
diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py
index d22e301d..a330cb46 100644
--- a/TTS/tts/models/forward_tts_e2e.py
+++ b/TTS/tts/models/forward_tts_e2e.py
@@ -2,10 +2,9 @@ import os
 from dataclasses import dataclass, field
 from itertools import chain
 from typing import Dict, List, Tuple, Union
+
 import numpy as np
 import pyworld as pw
-
-
 import torch
 import torch.distributed as dist
 from coqpit import Coqpit
@@ -20,15 +19,16 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.models.base_tts import BaseTTSE2E
 from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
 from TTS.tts.models.vits import load_audio, wav_to_mel
-from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy
 from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
+from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
+from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
 
 
 def id_to_torch(aux_id, cuda=False):
@@ -89,7 +89,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
     @staticmethod
     def _compute_and_save_pitch(config, wav_file, pitch_file=None):
         wav, _ = load_audio(wav_file)
-        f0 = compute_f0(x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax)
+        f0 = compute_f0(
+            x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
+        )
         # skip the last F0 value to align with the spectrogram
         if wav.shape[1] % config.hop_length != 0:
             f0 = f0[:-1]
@@ -632,7 +634,9 @@ class ForwardTTSE2e(BaseTTSE2E):
            figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
 
            # Sample audio
-           encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio)
+           encoder_audio = mel_to_wav_numpy(
+               mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.__mel_basis, **self.config.audio
+           )
            audios[f"{name_prefix}/encoder_audio"] = encoder_audio
 
            # vocoder outputs
@@ -780,7 +784,9 @@ class ForwardTTSE2e(BaseTTSE2E):
        outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})
 
        # collect outputs
-       wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio)
+       wav = mel_to_wav_numpy(
+           mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio
+       )
        alignments = outputs["alignments"]
        return_dict = {
            "wav": wav[None, :],
diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py
index c0a639a2..2633b83c 100644
--- a/TTS/utils/audio/numpy_transforms.py
+++ b/TTS/utils/audio/numpy_transforms.py
@@ -1,10 +1,10 @@
 from typing import Callable, Tuple
+
 import librosa
 import numpy as np
-
-import soundfile as sf
 import pyworld as pw
 import scipy
+import soundfile as sf
 
 # from TTS.tts.utils.helpers import StandardScaler
 
@@ -148,21 +148,15 @@ def wav_to_mel(*, y: np.ndarray = None, **kwargs) -> np.ndarray:
     return S.astype(np.float32)
 
 
-def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
+def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
     """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
     S = spec.copy()
-    if denorm_func is not None:
-        S = denorm_func(spec=S, **kwargs)
-        S = db_to_amp(S)
     return griffin_lim(spec=S**power, **kwargs)
 
 
-def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
+def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
     """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
     S = mel.copy()
-    if denorm_func is not None:
-        S = denorm_func(spec=S, **kwargs)
-        S = db_to_amp(S)
     S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"])  # Convert back to linear
     return griffin_lim(spec=S**power, **kwargs)
 
diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py
index 2e415df1..21e7b234 100644
--- a/TTS/utils/audio/torch_transforms.py
+++ b/TTS/utils/audio/torch_transforms.py
@@ -1,6 +1,6 @@
+import librosa
 import torch
 from torch import nn
-import librosa
 
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index 98a0a939..ce2b56fb 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -3,9 +3,9 @@ from typing import Dict
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 from TTS.utils.audio.processor import AudioProcessor
 
 
@@ -30,7 +30,13 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor=None, audio_config: "Coqpit"= None, name_prefix: str = None) -> Dict:
+def plot_results(
+    y_hat: torch.tensor,
+    y: torch.tensor,
+    ap: AudioProcessor = None,
+    audio_config: "Coqpit" = None,
+    name_prefix: str = None,
+) -> Dict:
     """Plot the predicted and the real waveform and their spectrograms.
 
     Args:
diff --git a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py
index b8463cd1..754c1a08 100644
--- a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py
+++ b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py
@@ -8,7 +8,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs, ForwardTTSE2eAudio
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 
-
 output_path = os.path.dirname(os.path.abspath(__file__))
 
 # init configs
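Note on the behavioral change above: `spec_to_wav` and `mel_to_wav` in `TTS/utils/audio/numpy_transforms.py` no longer denormalize internally (the `denorm_func` hook and the implicit `db_to_amp` call are removed), so a caller holding a dB-scale spectrogram must convert it back to amplitude before inversion, as the updated `forward_tts_e2e.py` call site does with `db_to_amp_numpy(x=pred_spec.T, gain=1, base=None)`. Below is a minimal sketch of the new calling convention. It assumes the keyword-only helper signatures used in this diff, and the `audio_config` values are hypothetical placeholders; a real caller reads them from the model's `config.audio` (e.g. `ForwardTTSE2eAudio` in the recipe).

import numpy as np

from TTS.utils.audio.numpy_transforms import build_mel_basis, db_to_amp, mel_to_wav

# Hypothetical STFT/mel settings, for illustration only.
audio_config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": None,
}

# Build the mel filterbank once, as the model does for `self.__mel_basis`.
mel_basis = build_mel_basis(**audio_config)

# `pred_spec` stands in for a predicted log-scale mel spectrogram, shape (T, num_mels).
pred_spec = np.random.uniform(-4.0, 4.0, size=(100, audio_config["num_mels"])).astype(np.float32)

# The dB -> amplitude step is now explicit at the call site; `gain=1, base=None`
# mirrors the natural-log convention used in the diff above.
wav = mel_to_wav(
    mel=db_to_amp(x=pred_spec.T, gain=1, base=None),
    mel_basis=mel_basis,
    **audio_config,
)

Keeping the denormalization at the call site means the Griffin-Lim helpers stay convention-agnostic: each model applies whatever dB gain/base it trained with, instead of `numpy_transforms` guessing one.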