Make style

Eren Gölge 2021-12-10 07:53:10 +00:00
parent 54b7fb4e4a
commit 704dddcffa
18 changed files with 96 additions and 75 deletions

View File

@@ -242,7 +242,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
speaker_manager = None
# setup model
model = setup_model(c)

View File

@@ -19,6 +19,7 @@ def compute_phonemes(item):
return []
return list(set(ph))
def main():
# pylint: disable=W0601
global c

View File

@@ -1,12 +1,13 @@
import os
import glob
import pathlib
import argparse
import glob
import multiprocessing
import os
import pathlib
from tqdm.contrib.concurrent import process_map
from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
def remove_silence(filepath):
output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
@@ -69,10 +70,7 @@ if __name__ == "__main__":
parser.add_argument(
"-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
)
parser.add_argument("-f", "--force",
default=False,
action='store_true',
help='Force the replace of exists files')
parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
parser.add_argument(
"-g",
"--glob",

View File

@@ -4,8 +4,8 @@ from TTS.config import load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

View File

@@ -100,7 +100,7 @@ if args.vocoder_path is not None:
# load models
synthesizer = Synthesizer(
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1

View File

@@ -2,11 +2,12 @@ import numpy as np
import torch
from torch import nn
from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
# import torchaudio
from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
class PreEmphasis(torch.nn.Module):
@@ -126,16 +127,16 @@ class ResNetSpeakerEncoder(nn.Module):
n_mels=audio_config["num_mels"],
power=2.0,
use_mel=True,
mel_norm=None
mel_norm=None,
),
'''torchaudio.transforms.MelSpectrogram(
"""torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),'''
),""",
)
else:
self.torch_spec = None

View File

@@ -531,7 +531,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"language_ids": language_ids
"language_ids": language_ids,
}
raise TypeError(

View File

@@ -588,7 +588,7 @@ class VitsGeneratorLoss(nn.Module):
@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
return l
def forward(
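
For context, the hunk above touches the speaker consistency loss (SCL): the negated mean cosine similarity between ground-truth and synthesized speaker embeddings. A self-contained sketch of the same computation, assuming embeddings of shape [batch, embedding_dim]:

    import torch
    import torch.nn.functional as F

    def cosine_similarity_loss(gt_spk_emb: torch.Tensor, syn_spk_emb: torch.Tensor) -> torch.Tensor:
        # Cosine similarity is 1.0 for identical directions, so negating the
        # mean turns "make the embeddings similar" into a quantity to minimize.
        return -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()

    gt, syn = torch.randn(4, 512), torch.randn(4, 512)
    loss = cosine_similarity_loss(gt, syn)  # scalar in [-1, 1]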

View File

@@ -2,11 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
def setup_model(
config,
speaker_manager: "SpeakerManager" = None,
language_manager: "LanguageManager" = None
):
def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
@@ -35,7 +31,7 @@ def setup_model(
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
if config.model.lower() in ["vits"]: # If model supports multiple languages
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)

View File

@@ -12,8 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -150,7 +150,13 @@ class BaseTTS(BaseModel):
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
@@ -337,14 +343,16 @@ class BaseTTS(BaseModel):
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
@@ -354,7 +362,6 @@ class BaseTTS(BaseModel):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,

View File

@@ -4,6 +4,7 @@ from itertools import chain
from typing import Dict, List, Tuple
import torch
# import torchaudio
from coqpit import Coqpit
from torch import nn
@@ -420,8 +421,9 @@ class Vits(BaseTTS):
):
# TODO: change this with torchaudio Resample
raise RuntimeError(
' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!'
.format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
)
)
# pylint: disable=W0101,W0105
""" self.audio_transform = torchaudio.transforms.Resample(
@@ -675,7 +677,6 @@ class Vits(BaseTTS):
)
return outputs
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
"""
Shapes:

View File

@@ -88,7 +88,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney"
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
@@ -155,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
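
One caveat worth noting for the call above: librosa 0.10 made the arguments of librosa.filters.mel keyword-only, so the positional form only works on older librosa releases. A sketch of the keyword form, with example parameter values (not the project's defaults):

    import librosa
    import torch

    # librosa >= 0.10 rejects positional arguments here.
    mel_basis = librosa.filters.mel(
        sr=22050,      # example sample rate
        n_fft=1024,
        n_mels=80,
        fmin=0,
        fmax=8000,
        htk=False,
        norm="slaney",
    )
    mel_basis = torch.from_numpy(mel_basis).float()  # shape: [n_mels, n_fft // 2 + 1]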

View File

@@ -7,8 +7,8 @@ import torch
from TTS.config import load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@@ -200,12 +200,7 @@ class Synthesizer(object):
self.ap.save_wav(wav, path, self.output_sample_rate)
def tts(
self,
text: str,
speaker_idx: str = "",
language_idx: str = "",
speaker_wav=None,
style_wav=None
self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@@ -254,7 +249,9 @@ class Synthesizer(object):
# handle multi-lingual
language_id = None
if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_idx and isinstance(language_idx, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]

View File

@@ -1,8 +1,9 @@
# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import wave
import webrtcvad
import contextlib
import collections
import contextlib
import wave
import webrtcvad
def read_wave(path):
@@ -37,7 +38,7 @@ class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, _bytes, timestamp, duration):
self.bytes =_bytes
self.bytes = _bytes
self.timestamp = timestamp
self.duration = duration
@@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
if voiced_frames:
yield b"".join([f.bytes for f in voiced_frames])
def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
vad = webrtcvad.Vad(int(aggressiveness))
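
Putting the helpers from this module together, a usage sketch for trimming silence from a file, assuming read_wave returns (pcm_bytes, sample_rate) as in the py-webrtcvad example this file is adapted from (file paths are hypothetical):

    from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

    audio, sample_rate = read_wave("speaker_0001.wav")
    # aggressiveness 0-3: higher values discard more non-speech audio
    segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2)
    voiced_audio = b"".join(segments)  # each segment is a bytes chunk of PCM audio
    write_wave("speaker_0001_trimmed.wav", voiced_audio, sample_rate)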

View File

@@ -7,15 +7,18 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
mailabs_path = '/home/julian/workspace/mailabs/**'
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
dataset_config = [
BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
for path in dataset_paths
]
audio_config = BaseAudioConfig(
sample_rate=16000,
@@ -61,7 +64,7 @@ config = VitsConfig(
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
use_language_weighted_sampler= True,
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
@@ -69,21 +72,31 @@ config = VitsConfig(
max_seq_len=160000,
output_path=output_path,
datasets=dataset_config,
characters= {
characters={
"pad": "_",
"eos": "&",
"bos": "*",
"characters": "'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
"punctuations": "'(),-.:;¿? ",
"phonemes": None,
"unique": True
"unique": True,
},
test_sentences=[
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
]
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"mary_ann",
None,
"en_US",
],
[
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
"ezwa",
None,
"fr_FR",
],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
],
)
# init audio processor

View File

@@ -26,3 +26,4 @@ unidic-lite==1.0.8
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
fsspec>=2021.04.0
pyworld
webrtcvad

View File

@@ -31,7 +31,7 @@ dataset_config_pt = BaseDatasetConfig(
class TestFindUniquePhonemes(unittest.TestCase):
@staticmethod
def test_espeak_phonemes():
# prepare the config
# prepare the config
config = VitsConfig(
batch_size=2,
eval_batch_size=2,
@@ -52,9 +52,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
config.save_json(config_path)
# run test
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
)
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
@staticmethod
def test_no_espeak_phonemes():
@@ -79,6 +77,4 @@ class TestFindUniquePhonemes(unittest.TestCase):
config.save_json(config_path)
# run test
run_cli(
f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
)
run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

View File

@@ -1,9 +1,11 @@
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.utils.languages import get_language_weighted_sampler
import torch
import functools
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
@@ -25,18 +27,19 @@ dataset_config_pt = BaseDatasetConfig(
# Adding the EN samples twice to create an unbalanced dataset
train_samples, eval_samples = load_tts_samples(
[dataset_config_en, dataset_config_en, dataset_config_pt],
eval_split=True
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)
def is_balanced(lang_1, lang_2):
return 0.85 < lang_1/lang_2 < 1.2
return 0.85 < lang_1 / lang_2 < 1.2
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == 'en':
if train_samples[index][3] == "en":
en += 1
else:
pt += 1
@@ -47,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == 'en':
if train_samples[index][3] == "en":
en += 1
else:
pt += 1
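
For reference, a language-weighted sampler of the kind this test exercises can be built from torch's WeightedRandomSampler by weighting each sample with the inverse frequency of its language. A minimal sketch, not necessarily the library's exact implementation, assuming the language code sits at index 3 of each sample as in this test:

    from collections import Counter

    import torch

    def language_weighted_sampler(items):
        langs = [item[3] for item in items]
        counts = Counter(langs)
        # Inverse-frequency weights: samples from under-represented languages
        # are drawn proportionally more often, balancing the languages in expectation.
        weights = torch.tensor([1.0 / counts[lang] for lang in langs], dtype=torch.double)
        return torch.utils.data.WeightedRandomSampler(weights, len(weights))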