mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
54b7fb4e4a
commit
704dddcffa
@@ -242,7 +242,6 @@ def main(args):  # pylint: disable=redefined-outer-name
    else:
        speaker_manager = None

    # setup model
    model = setup_model(c)
@@ -19,6 +19,7 @@ def compute_phonemes(item):
        return []
    return list(set(ph))


def main():
    # pylint: disable=W0601
    global c
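The two blank lines before def main() are black's required spacing between top-level definitions. For context, find_unique_phonemes.py maps compute_phonemes over the dataset items and unions the per-item phoneme sets; a minimal self-contained sketch of that aggregation (the phonemize stand-in below is a placeholder, not the script's actual phonemizer):

    from multiprocessing import Pool

    def phonemize(text):
        # stand-in for the real phonemizer the script uses
        return list(text)

    def compute_phonemes(item):
        try:
            text = item[0]
            return list(set(phonemize(text)))  # unique phonemes of one sample
        except Exception:
            return []

    def collect_unique_phonemes(items, num_workers=4):
        phones = set()
        with Pool(num_workers) as pool:
            for ph in pool.map(compute_phonemes, items):
                phones.update(ph)
        return sorted(phones)

    if __name__ == "__main__":
        print(collect_unique_phonemes([("hello world",), ("coqui tts",)]))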
@@ -1,12 +1,13 @@
import os
import glob
import pathlib
import argparse
import glob
import multiprocessing
import os
import pathlib

from tqdm.contrib.concurrent import process_map

from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave


def remove_silence(filepath):
    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
@@ -69,10 +70,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
    )
    parser.add_argument("-f", "--force",
                        default=False,
                        action='store_true',
                        help='Force the replace of exists files')
    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
    parser.add_argument(
        "-g",
        "--glob",
@@ -4,8 +4,8 @@ from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@@ -100,7 +100,7 @@ if args.vocoder_path is not None:

# load models
synthesizer = Synthesizer(
    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
    model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)

use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
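The only functional change in this hunk is the extra None positional argument: the Synthesizer constructor now takes a TTS languages file between the speakers file and the vocoder paths. A hedged construction example with placeholder paths (the argument meanings are inferred from this call, not from the class definition):

    from TTS.utils.synthesizer import Synthesizer

    # All paths below are hypothetical placeholders.
    synthesizer = Synthesizer(
        "tts_model.pth.tar",      # model_path
        "tts_config.json",        # config_path
        "speakers.json",          # speakers_file_path
        None,                     # languages file; None for monolingual models
        "vocoder_model.pth.tar",  # vocoder_path
        "vocoder_config.json",    # vocoder_config_path
        use_cuda=False,
    )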
@@ -2,11 +2,12 @@ import numpy as np
import torch
from torch import nn

from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec

# import torchaudio

from TTS.utils.audio import TorchSTFT

from TTS.utils.io import load_fsspec


class PreEmphasis(torch.nn.Module):
@@ -126,16 +127,16 @@ class ResNetSpeakerEncoder(nn.Module):
                    n_mels=audio_config["num_mels"],
                    power=2.0,
                    use_mel=True,
                    mel_norm=None
                    mel_norm=None,
                ),
                '''torchaudio.transforms.MelSpectrogram(
                """torchaudio.transforms.MelSpectrogram(
                    sample_rate=audio_config["sample_rate"],
                    n_fft=audio_config["fft_size"],
                    win_length=audio_config["win_length"],
                    hop_length=audio_config["hop_length"],
                    window_fn=torch.hamming_window,
                    n_mels=audio_config["num_mels"],
                ),'''
                ),""",
            )
        else:
            self.torch_spec = None
@@ -531,7 +531,7 @@ class TTSDataset(Dataset):
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "language_ids": language_ids
                "language_ids": language_ids,
            }

        raise TypeError(
@@ -588,7 +588,7 @@ class VitsGeneratorLoss(nn.Module):

    @staticmethod
    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
        l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
        return l

    def forward(
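Black only drops the space after the unary minus; the speaker-consistency loss itself is unchanged: the negated mean cosine similarity between ground-truth and synthesized speaker embeddings, so more similar embeddings give a lower loss. A small self-contained sanity check of that behaviour:

    import torch
    import torch.nn.functional as F

    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
        return -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()

    gt = torch.randn(4, 256)
    print(cosine_similarity_loss(gt, gt))                   # close to -1.0 (identical embeddings)
    print(cosine_similarity_loss(gt, torch.randn(4, 256)))  # close to 0.0 on average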
@@ -2,11 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module


def setup_model(
    config,
    speaker_manager: "SpeakerManager" = None,
    language_manager: "LanguageManager" = None
):
def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "base_model" in config and config["base_model"] is not None:
@@ -35,7 +31,7 @@ def setup_model(
            config.model_params.num_chars = num_chars
        if "model_args" in config:
            config.model_args.num_chars = num_chars
    if config.model.lower() in ["vits"]: # If model supports multiple languages
    if config.model.lower() in ["vits"]:  # If model supports multiple languages
        model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
    else:
        model = MyModel(config, speaker_manager=speaker_manager)
@@ -12,8 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -150,7 +150,13 @@ class BaseTTS(BaseModel):
        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.language_id_mapping[language_name]

        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
        }

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `TTSDataset`.
@@ -337,14 +343,16 @@ class BaseTTS(BaseModel):
            if config.compute_f0:
                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))

            # sampler for DDP
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None

            # Weighted samplers
            assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
            assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
            assert not (
                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
            ), "language_weighted_sampler is not supported with DistributedSampler"
            assert not (
                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
            ), "speaker_weighted_sampler is not supported with DistributedSampler"

            if sampler is None:
                if getattr(config, "use_language_weighted_sampler", False):
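The reformatted asserts keep the same guard: the weighted samplers cannot be combined with DistributedSampler, so they are rejected whenever num_gpus > 1. As a rough sketch of what a language-weighted sampler amounts to (an assumption about the general idea behind get_language_weighted_sampler, not its actual code), inverse-frequency weights per sample are fed to torch's WeightedRandomSampler:

    from collections import Counter

    import torch

    def language_weighted_sampler(items):
        # items are (text, wav_file, speaker_name, language) tuples in this codebase
        languages = [item[3] for item in items]
        counts = Counter(languages)
        # rare languages get proportionally larger sampling weights
        weights = torch.tensor([1.0 / counts[lang] for lang in languages], dtype=torch.double)
        return torch.utils.data.WeightedRandomSampler(weights, len(weights))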
@@ -354,7 +362,6 @@ class BaseTTS(BaseModel):
                    print(" > Using Language weighted sampler")
                    sampler = get_speaker_weighted_sampler(dataset.items)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
@@ -4,6 +4,7 @@ from itertools import chain
from typing import Dict, List, Tuple

import torch

# import torchaudio
from coqpit import Coqpit
from torch import nn
@@ -420,8 +421,9 @@ class Vits(BaseTTS):
        ):
            # TODO: change this with torchaudio Resample
            raise RuntimeError(
                ' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!'
                .format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
                " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
                    self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
                )
            )
        # pylint: disable=W0101,W0105
        """ self.audio_transform = torchaudio.transforms.Resample(
@@ -675,7 +677,6 @@ class Vits(BaseTTS):
        )
        return outputs

    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
        """
        Shapes:
@@ -88,7 +88,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney"
        mel_norm="slaney",
    ):
        super().__init__()
        self.n_fft = n_fft
@@ -155,7 +155,13 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
            self.sample_rate,
            self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()
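The expanded call passes exactly the same arguments to librosa.filters.mel. Note that recent librosa releases make these parameters keyword-only, so an equivalent sketch written defensively would name every argument (the version behaviour is an assumption here, not something this diff addresses):

    import librosa
    import torch

    def build_mel_basis(sample_rate, n_fft, n_mels=80, fmin=0.0, fmax=None, use_htk=False, mel_norm="slaney"):
        mel_basis = librosa.filters.mel(
            sr=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=use_htk,
            norm=mel_norm,
        )  # shape: (n_mels, n_fft // 2 + 1)
        return torch.from_numpy(mel_basis).float()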
@@ -7,8 +7,8 @@ import torch

from TTS.config import load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@@ -200,12 +200,7 @@ class Synthesizer(object):
        self.ap.save_wav(wav, path, self.output_sample_rate)

    def tts(
        self,
        text: str,
        speaker_idx: str = "",
        language_idx: str = "",
        speaker_wav=None,
        style_wav=None
        self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
    ) -> List[int]:
        """🐸 TTS magic. Run all the models and generate speech.
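The collapsed signature keeps the same API: tts() takes the text plus optional speaker_idx, language_idx, speaker_wav and style_wav, and returns the waveform as a list of samples. A hedged usage example, assuming a Synthesizer instance loaded with a multi-speaker, multilingual model (the speaker and language names are placeholders):

    wav = synthesizer.tts(
        "Bonjour tout le monde.",
        speaker_idx="ezwa",      # must exist in the model's speaker list
        language_idx="fr_FR",    # must exist in the model's language list
    )
    synthesizer.save_wav(wav, "bonjour.wav")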
@@ -254,7 +249,9 @@ class Synthesizer(object):

        # handle multi-lingaul
        language_id = None
        if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
        if self.tts_languages_file or (
            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
        ):
            if language_idx and isinstance(language_idx, str):
                language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
@@ -1,8 +1,9 @@
# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import wave
import webrtcvad
import contextlib
import collections
import contextlib
import wave

import webrtcvad


def read_wave(path):
@@ -37,7 +38,7 @@ class Frame(object):
    """Represents a "frame" of audio data."""

    def __init__(self, _bytes, timestamp, duration):
        self.bytes =_bytes
        self.bytes = _bytes
        self.timestamp = timestamp
        self.duration = duration
@@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, fram
    if voiced_frames:
        yield b"".join([f.bytes for f in voiced_frames])


def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):

    vad = webrtcvad.Vad(int(aggressiveness))
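get_vad_speech_segments is a thin wrapper over py-webrtcvad: it frames the PCM audio, classifies each frame, and yields padded voiced regions. A rough usage sketch with the helpers imported at the top of this module (the exact return values of read_wave and the argument order of write_wave are assumptions taken from the upstream example this file adapts):

    from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

    audio, sample_rate = read_wave("speech.wav")  # assumed to return (pcm_bytes, sample_rate)
    segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300)
    for i, segment in enumerate(segments):
        write_wave("segment_%02d.wav" % i, segment, sample_rate)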
@@ -7,15 +7,18 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

mailabs_path = '/home/julian/workspace/mailabs/**'
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
dataset_config = [
    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
    for path in dataset_paths
]

audio_config = BaseAudioConfig(
    sample_rate=16000,
@@ -61,7 +64,7 @@ config = VitsConfig(
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    print_step=25,
    use_language_weighted_sampler= True,
    use_language_weighted_sampler=True,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
@@ -69,21 +72,31 @@ config = VitsConfig(
    max_seq_len=160000,
    output_path=output_path,
    datasets=dataset_config,
    characters= {
    characters={
        "pad": "_",
        "eos": "&",
        "bos": "*",
        "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
        "punctuations": "!¡'(),-.:;¿? ",
        "phonemes": None,
        "unique": True
        "unique": True,
    },
    test_sentences=[
        ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
        ["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
    ]
        [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "mary_ann",
            None,
            "en_US",
        ],
        [
            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
            "ezwa",
            None,
            "fr_FR",
        ],
        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
    ],
)

# init audio processor
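Black only re-wraps the long entries; each test_sentences item keeps the shape [text, speaker_name, style_wav, language_code]. A minimal hedged example of adding another entry in the same shape (the speaker and language values are placeholders that must exist in the dataset):

    test_sentences = [
        # [text, speaker_name, style_wav, language_code]
        ["This is a placeholder sentence.", "mary_ann", None, "en_US"],
    ]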
@@ -26,3 +26,4 @@ unidic-lite==1.0.8
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
fsspec>=2021.04.0
pyworld
webrtcvad
@@ -31,7 +31,7 @@ dataset_config_pt = BaseDatasetConfig(
class TestFindUniquePhonemes(unittest.TestCase):
    @staticmethod
    def test_espeak_phonemes():
        # prepare the config
        # prepare the config
        config = VitsConfig(
            batch_size=2,
            eval_batch_size=2,
@@ -52,9 +52,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
        config.save_json(config_path)

        # run test
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
        )
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

    @staticmethod
    def test_no_espeak_phonemes():
@@ -79,6 +77,4 @@ class TestFindUniquePhonemes(unittest.TestCase):
        config.save_json(config_path)

        # run test
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
        )
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
@@ -1,9 +1,11 @@
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.utils.languages import get_language_weighted_sampler
import torch
import functools

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler

# Fixing random state to avoid random fails
torch.manual_seed(0)
@@ -25,18 +27,19 @@ dataset_config_pt = BaseDatasetConfig(

# Adding the EN samples twice to create an unbalanced dataset
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt],
    eval_split=True
    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)


def is_balanced(lang_1, lang_2):
    return 0.85 < lang_1/lang_2 < 1.2
    return 0.85 < lang_1 / lang_2 < 1.2


random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
    if train_samples[index][3] == 'en':
    if train_samples[index][3] == "en":
        en += 1
    else:
        pt += 1
@@ -47,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
    if train_samples[index][3] == 'en':
    if train_samples[index][3] == "en":
        en += 1
    else:
        pt += 1