mirror of https://github.com/coqui-ai/TTS.git

commit 9291d13c69 ("Make style"), parent edd59c81e8
@@ -11,7 +11,6 @@ from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
 
-
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")
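For context, a minimal self-contained sketch (not part of the commit) of why the hunk above sets the sharing strategy: with PyTorch's default "file_descriptor" strategy, every tensor a DataLoader worker sends to the main process pins an open file descriptor, so large datasets with many workers can exceed the process fd limit.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in dataset; every item crosses the worker/main-process boundary."""

    def __len__(self):
        return 10_000

    def __getitem__(self, idx):
        return torch.randn(100)


if __name__ == "__main__":
    # Back shared tensors with files in shared memory instead of file
    # descriptors, avoiding "too many open files" with many small tensors.
    torch.multiprocessing.set_sharing_strategy("file_system")
    loader = DataLoader(ToyDataset(), batch_size=32, num_workers=4)
    for _ in loader:
        pass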
@@ -2,10 +2,9 @@ import os
 from dataclasses import dataclass, field
 from itertools import chain
 from typing import Dict, List, Tuple, Union
 
 import numpy as np
 import pyworld as pw
-
 
 import torch
 import torch.distributed as dist
 from coqpit import Coqpit
@@ -20,15 +19,16 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.models.base_tts import BaseTTSE2E
 from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
 from TTS.tts.models.vits import load_audio, wav_to_mel
-from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy
 from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
+from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
+from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
 
 
 def id_to_torch(aux_id, cuda=False):
@@ -89,7 +89,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
     @staticmethod
     def _compute_and_save_pitch(config, wav_file, pitch_file=None):
         wav, _ = load_audio(wav_file)
-        f0 = compute_f0(x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax)
+        f0 = compute_f0(
+            x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
+        )
         # skip the last F0 value to align with the spectrogram
         if wav.shape[1] % config.hop_length != 0:
             f0 = f0[:-1]
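The trimming above compensates for a one-frame length mismatch. An illustrative calculation (assumptions, not repo code): a centered STFT yields len(wav) // hop + 1 frames, while an F0 extractor that also covers the padded tail yields ceil(len(wav) / hop) + 1 values, which is one extra exactly when len(wav) % hop != 0.

import numpy as np

hop_length = 256
for n_samples in (4 * hop_length, 4 * hop_length + 100):
    n_spec = n_samples // hop_length + 1             # centered STFT frame count
    n_f0 = int(np.ceil(n_samples / hop_length)) + 1  # assumed F0 frame count
    print(f"remainder={n_samples % hop_length:>3}  spec={n_spec}  f0={n_f0}  extra={n_f0 - n_spec}")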
@@ -632,7 +634,9 @@ class ForwardTTSE2e(BaseTTSE2E):
         figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
 
         # Sample audio
-        encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio)
+        encoder_audio = mel_to_wav_numpy(
+            mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.__mel_basis, **self.config.audio
+        )
         audios[f"{name_prefix}/encoder_audio"] = encoder_audio
 
         # vocoder outputs
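The new db_to_amp_numpy(...) wrapper indicates that pred_spec is a log-compressed mel which must be mapped back to linear amplitudes before Griffin-Lim. A hedged sketch of such a round trip, assuming base=None selects the natural logarithm and gain scales the log; the real helpers' semantics may differ.

import numpy as np


def amp_to_db(x: np.ndarray, gain: float = 1.0, base=None) -> np.ndarray:
    # base=None is assumed here to mean the natural logarithm.
    log = np.log if base is None else (lambda a: np.log(a) / np.log(base))
    return gain * log(np.maximum(1e-8, x))


def db_to_amp(x: np.ndarray, gain: float = 1.0, base=None) -> np.ndarray:
    # Inverse of amp_to_db for the same gain/base.
    exp = np.exp if base is None else (lambda a: base**a)
    return exp(x / gain)


mel = np.abs(np.random.randn(80, 120))  # stand-in for a predicted mel
assert np.allclose(db_to_amp(amp_to_db(mel)), np.maximum(1e-8, mel))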
@@ -780,7 +784,9 @@ class ForwardTTSE2e(BaseTTSE2E):
         outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})
 
         # collect outputs
-        wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio)
+        wav = mel_to_wav_numpy(
+            mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio
+        )
         alignments = outputs["alignments"]
         return_dict = {
             "wav": wav[None, :],
@@ -1,10 +1,10 @@
 from typing import Callable, Tuple
 
 import librosa
 import numpy as np
-import soundfile as sf
 import pyworld as pw
 import scipy
+import soundfile as sf
 
 # from TTS.tts.utils.helpers import StandardScaler
 
@@ -148,21 +148,15 @@ def wav_to_mel(*, y: np.ndarray = None, **kwargs) -> np.ndarray:
     return S.astype(np.float32)
 
 
-def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
+def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
     """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
     S = spec.copy()
-    if denorm_func is not None:
-        S = denorm_func(spec=S, **kwargs)
-        S = db_to_amp(S)
     return griffin_lim(spec=S**power, **kwargs)
 
 
-def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, denorm_func: Callable = None, **kwargs) -> np.ndarray:
+def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
     """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
     S = mel.copy()
-    if denorm_func is not None:
-        S = denorm_func(spec=S, **kwargs)
-        S = db_to_amp(S)
     S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"])  # Convert back to linear
     return griffin_lim(spec=S**power, **kwargs)
 
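After this change, spec_to_wav and mel_to_wav expect linear-amplitude input and leave denormalization to the caller. A self-contained sketch of the same mel -> linear -> Griffin-Lim path using plain librosa; all parameter values and the synthetic input here are hypothetical, not taken from the repo.

import librosa
import numpy as np

sample_rate, n_fft, hop_length, n_mels, power = 22050, 1024, 256, 80, 1.5

# Synthetic test signal (a two-second linear chirp) instead of a real recording.
t = np.linspace(0, 2.0, int(2.0 * sample_rate), endpoint=False)
y = 0.5 * np.sin(2 * np.pi * (200 + 400 * t) * t).astype(np.float32)

mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
spec = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
mel = mel_basis @ spec  # linear-amplitude mel, the input the trimmed API expects

# mel_to_spec equivalent: pseudo-inverse of the mel filterbank.
spec_hat = np.maximum(1e-10, np.linalg.pinv(mel_basis) @ mel)

# griffin_lim equivalent: iterative phase recovery on the sharpened magnitude.
wav = librosa.griffinlim(spec_hat**power, hop_length=hop_length, n_iter=32)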
@@ -1,6 +1,6 @@
+import librosa
 import torch
 from torch import nn
-import librosa
 
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
@@ -3,9 +3,9 @@ from typing import Dict
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel
 from TTS.utils.audio.processor import AudioProcessor
 
 
@@ -30,7 +30,13 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor=None, audio_config: "Coqpit"= None, name_prefix: str = None) -> Dict:
+def plot_results(
+    y_hat: torch.tensor,
+    y: torch.tensor,
+    ap: AudioProcessor = None,
+    audio_config: "Coqpit" = None,
+    name_prefix: str = None,
+) -> Dict:
     """Plot the predicted and the real waveform and their spectrograms.
 
     Args:
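A hypothetical miniature of what such a comparison figure contains: real and generated waveforms stacked over their log-magnitude spectrograms. This sketch is not the repository's implementation, and its STFT parameters are assumptions.

from typing import Dict

import librosa
import matplotlib.pyplot as plt
import numpy as np


def plot_results_sketch(y_hat: np.ndarray, y: np.ndarray, name_prefix: str = "eval") -> Dict:
    fig, axes = plt.subplots(2, 2, figsize=(10, 6))
    for col, (wav, title) in enumerate([(y, "real"), (y_hat, "generated")]):
        axes[0, col].plot(wav)
        axes[0, col].set_title(f"{name_prefix}/{title} waveform")
        S = np.abs(librosa.stft(wav, n_fft=1024, hop_length=256))
        axes[1, col].imshow(np.log(np.maximum(1e-8, S)), origin="lower", aspect="auto")
        axes[1, col].set_title(f"{name_prefix}/{title} spectrogram (log-magnitude)")
    fig.tight_layout()
    return {f"{name_prefix}/figure": fig}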
@@ -8,7 +8,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs, ForwardTTSE2eAudio
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 
-
 output_path = os.path.dirname(os.path.abspath(__file__))
 
 # init configs