Make style

This commit is contained in:
Eren Gölge 2022-04-19 10:59:59 +02:00 committed by Eren Gölge
parent edd59c81e8
commit 9291d13c69
6 changed files with 27 additions and 23 deletions

View File

@ -11,7 +11,6 @@ from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")

View File

@ -2,10 +2,9 @@ import os
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple, Union
import numpy as np
import pyworld as pw
import torch
import torch.distributed as dist
from coqpit import Coqpit
@ -20,15 +19,16 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.models.base_tts import BaseTTSE2E
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.models.vits import load_audio, wav_to_mel
from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy
from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
def id_to_torch(aux_id, cuda=False):
@ -89,7 +89,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
@staticmethod
def _compute_and_save_pitch(config, wav_file, pitch_file=None):
wav, _ = load_audio(wav_file)
f0 = compute_f0(x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax)
f0 = compute_f0(
x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
)
# skip the last F0 value to align with the spectrogram
if wav.shape[1] % config.hop_length != 0:
f0 = f0[:-1]
@ -632,7 +634,9 @@ class ForwardTTSE2e(BaseTTSE2E):
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
# Sample audio
encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio)
encoder_audio = mel_to_wav_numpy(
mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.__mel_basis, **self.config.audio
)
audios[f"{name_prefix}/encoder_audio"] = encoder_audio
# vocoder outputs
@ -780,7 +784,9 @@ class ForwardTTSE2e(BaseTTSE2E):
outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})
# collect outputs
wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio)
wav = mel_to_wav_numpy(
mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio
)
alignments = outputs["alignments"]
return_dict = {
"wav": wav[None, :],

View File

@ -1,10 +1,10 @@
from typing import Callable, Tuple
import librosa
import numpy as np
import soundfile as sf
import pyworld as pw
import scipy
import soundfile as sf
# from TTS.tts.utils.helpers import StandardScaler
@ -148,21 +148,15 @@ def wav_to_mel(*, y: np.ndarray = None, **kwargs) -> np.ndarray:
return S.astype(np.float32)
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = spec.copy()
if denorm_func is not None:
S = denorm_func(spec=S, **kwargs)
S = db_to_amp(S)
return griffin_lim(spec=S**power, **kwargs)
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
S = mel.copy()
if denorm_func is not None:
S = denorm_func(spec=S, **kwargs)
S = db_to_amp(S)
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
return griffin_lim(spec=S**power, **kwargs)

View File

@ -1,6 +1,6 @@
import librosa
import torch
from torch import nn
import librosa
class TorchSTFT(nn.Module): # pylint: disable=abstract-method

View File

@ -3,9 +3,9 @@ from typing import Dict
import numpy as np
import torch
from matplotlib import pyplot as plt
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
from TTS.utils.audio.processor import AudioProcessor
@ -30,7 +30,13 @@ def interpolate_vocoder_input(scale_factor, spec):
return spec
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor=None, audio_config: "Coqpit"= None, name_prefix: str = None) -> Dict:
def plot_results(
y_hat: torch.tensor,
y: torch.tensor,
ap: AudioProcessor = None,
audio_config: "Coqpit" = None,
name_prefix: str = None,
) -> Dict:
"""Plot the predicted and the real waveform and their spectrograms.
Args:

View File

@ -8,7 +8,6 @@ from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs, ForwardTTSE2eAudio
from TTS.tts.utils.text.tokenizer import TTSTokenizer
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs