Make style

Eren Gölge 2021-12-10 07:53:10 +00:00
parent b4bb0ace70
commit 66b6e9bc99
18 changed files with 96 additions and 75 deletions

View File

@@ -242,7 +242,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         speaker_manager = None

     # setup model
     model = setup_model(c)

View File

@@ -19,6 +19,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))

+
 def main():
     # pylint: disable=W0601
     global c

View File

@@ -1,12 +1,13 @@
-import os
-import glob
-import pathlib
 import argparse
+import glob
 import multiprocessing
+import os
+import pathlib

 from tqdm.contrib.concurrent import process_map

-from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
+from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave


 def remove_silence(filepath):
     output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
@@ -69,10 +70,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
     )
-    parser.add_argument("-f", "--force",
-                        default=False,
-                        action='store_true',
-                        help='Force the replace of exists files')
+    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
     parser.add_argument(
         "-g",
         "--glob",

View File

@@ -4,8 +4,8 @@ from TTS.config import load_config, register_config
 from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor

View File

@@ -100,7 +100,7 @@ if args.vocoder_path is not None:
 # load models
 synthesizer = Synthesizer(
-    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
+    model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
 )
 use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1

View File

@@ -2,11 +2,12 @@ import numpy as np
 import torch
 from torch import nn

-from TTS.utils.audio import TorchSTFT
-from TTS.utils.io import load_fsspec
 # import torchaudio
+from TTS.utils.audio import TorchSTFT
+from TTS.utils.io import load_fsspec


 class PreEmphasis(torch.nn.Module):
@@ -126,16 +127,16 @@ class ResNetSpeakerEncoder(nn.Module):
                 n_mels=audio_config["num_mels"],
                 power=2.0,
                 use_mel=True,
-                mel_norm=None
+                mel_norm=None,
             ),
-            '''torchaudio.transforms.MelSpectrogram(
+            """torchaudio.transforms.MelSpectrogram(
                 sample_rate=audio_config["sample_rate"],
                 n_fft=audio_config["fft_size"],
                 win_length=audio_config["win_length"],
                 hop_length=audio_config["hop_length"],
                 window_fn=torch.hamming_window,
                 n_mels=audio_config["num_mels"],
-            ),'''
+            ),""",
             )
         else:
             self.torch_spec = None

View File

@@ -531,7 +531,7 @@ class TTSDataset(Dataset):
                 "waveform": wav_padded,
                 "raw_text": batch["raw_text"],
                 "pitch": pitch,
-                "language_ids": language_ids
+                "language_ids": language_ids,
             }

         raise TypeError(

View File

@@ -588,7 +588,7 @@ class VitsGeneratorLoss(nn.Module):
     @staticmethod
     def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
-        l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
         return l

     def forward(
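Note: the reformatted cosine_similarity_loss only changes spacing; the behavior is unchanged. As a minimal runnable sketch of what the term computes (the batch and embedding sizes here are illustrative, not from the source):

import torch

# Negated mean cosine similarity between ground-truth and synthesized speaker
# embeddings: identical embeddings give the minimum value -1.0, orthogonal
# embeddings give 0.0. Shapes are assumed to be (batch, embedding_dim).
gt_spk_emb = torch.randn(4, 512)
syn_spk_emb = gt_spk_emb.clone()
loss = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
print(loss)  # tensor(-1.) when the embeddings match exactly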

View File

@@ -2,11 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module


-def setup_model(
-    config,
-    speaker_manager: "SpeakerManager" = None,
-    language_manager: "LanguageManager" = None
-):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:
@@ -35,7 +31,7 @@ def setup_model(
         config.model_params.num_chars = num_chars
     if "model_args" in config:
         config.model_args.num_chars = num_chars
-    if config.model.lower() in ["vits"]: # If model supports multiple languages
+    if config.model.lower() in ["vits"]:  # If model supports multiple languages
         model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
     else:
         model = MyModel(config, speaker_manager=speaker_manager)

View File

@@ -12,8 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
 from TTS.model import BaseModel
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -150,7 +150,13 @@ class BaseTTS(BaseModel):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]

-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }

     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -337,14 +343,16 @@ class BaseTTS(BaseModel):
             if config.compute_f0:
                 dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))

             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None

             # Weighted samplers
-            assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
-            assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+            ), "language_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+            ), "speaker_weighted_sampler is not supported with DistributedSampler"

             if sampler is None:
                 if getattr(config, "use_language_weighted_sampler", False):
@@ -354,7 +362,6 @@ class BaseTTS(BaseModel):
                     print(" > Using Language weighted sampler")
                     sampler = get_speaker_weighted_sampler(dataset.items)

             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
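Note: the asserts above exist because DataLoader accepts a single sampler, so a weighted sampler cannot be stacked on DistributedSampler. A runnable sketch of that mutual exclusion, with a plain WeightedRandomSampler standing in for get_language_weighted_sampler (the dataset and weights are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.arange(10))

# Under DDP (num_gpus > 1) this would be DistributedSampler(dataset), and the
# weighted branch below would have to stay disabled, mirroring the asserts.
sampler = None
if sampler is None:
    weights = torch.ones(len(dataset))  # per-sample weights, e.g. by language
    sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print(next(iter(loader)))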

View File

@@ -4,6 +4,7 @@ from itertools import chain
 from typing import Dict, List, Tuple

 import torch
+
 # import torchaudio
 from coqpit import Coqpit
 from torch import nn
@@ -420,8 +421,9 @@ class Vits(BaseTTS):
             ):
                 # TODO: change this with torchaudio Resample
                 raise RuntimeError(
-                    ' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!'
-                    .format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
+                    " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
+                        self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
+                    )
                 )
             # pylint: disable=W0101,W0105
             """ self.audio_transform = torchaudio.transforms.Resample(
@@ -675,7 +677,6 @@ class Vits(BaseTTS):
         )
         return outputs

     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
         Shapes:

View File

@@ -88,7 +88,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         spec_gain=1.0,
         power=None,
         use_htk=False,
-        mel_norm="slaney"
+        mel_norm="slaney",
     ):
         super().__init__()
         self.n_fft = n_fft
@@ -155,7 +155,13 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     def _build_mel_basis(self):
         mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
+            self.sample_rate,
+            self.n_fft,
+            n_mels=self.n_mels,
+            fmin=self.mel_fmin,
+            fmax=self.mel_fmax,
+            htk=self.use_htk,
+            norm=self.mel_norm,
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
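Note: self.sample_rate and self.n_fft are still passed positionally above; newer librosa releases (0.10 onward) make filters.mel keyword-only, so a future-proof call needs sr= and n_fft=. A standalone equivalent of _build_mel_basis with illustrative values:

import librosa

# Keyword-only form of the same call; the parameter values are examples.
mel_basis = librosa.filters.mel(
    sr=16000, n_fft=1024, n_mels=80, fmin=0, fmax=8000, htk=False, norm="slaney"
)
print(mel_basis.shape)  # (80, 513) == (n_mels, n_fft // 2 + 1)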

View File

@@ -7,8 +7,8 @@ import torch

 from TTS.config import load_config
 from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager

 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
@@ -200,12 +200,7 @@ class Synthesizer(object):
         self.ap.save_wav(wav, path, self.output_sample_rate)

     def tts(
-        self,
-        text: str,
-        speaker_idx: str = "",
-        language_idx: str = "",
-        speaker_wav=None,
-        style_wav=None
+        self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
@@ -254,7 +249,9 @@ class Synthesizer(object):
         # handle multi-lingaul
         language_id = None
-        if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
+        if self.tts_languages_file or (
+            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+        ):
             if language_idx and isinstance(language_idx, str):
                 language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
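A hedged usage sketch of the reformatted tts() signature; the paths are placeholders, and the remaining constructor arguments are assumed to keep their defaults (the server call earlier in this commit passes them positionally):

from TTS.utils.synthesizer import Synthesizer

# "model.pth" and "config.json" must point at a real checkpoint and config.
synthesizer = Synthesizer("model.pth", "config.json", use_cuda=False)
wav = synthesizer.tts("Hello world.", speaker_idx="", language_idx="")
synthesizer.save_wav(wav, "out.wav")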

View File

@@ -1,8 +1,9 @@
 # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import wave
-import webrtcvad
-import contextlib
 import collections
+import contextlib
+import wave
+
+import webrtcvad


 def read_wave(path):
@@ -37,7 +38,7 @@ class Frame(object):
     """Represents a "frame" of audio data."""

     def __init__(self, _bytes, timestamp, duration):
-        self.bytes =_bytes
+        self.bytes = _bytes
         self.timestamp = timestamp
         self.duration = duration
@@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
     if voiced_frames:
         yield b"".join([f.bytes for f in voiced_frames])


 def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):

     vad = webrtcvad.Vad(int(aggressiveness))
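A hedged usage sketch for the helpers this file exports, assuming read_wave returns the PCM bytes plus the sample rate and get_vad_speech_segments yields the voiced chunks as bytes (the file paths are placeholders):

from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

audio, sample_rate = read_wave("input.wav")  # placeholder input path
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2)
# Concatenate the voiced segments and write the silence-stripped result.
write_wave("output.wav", b"".join(segments), sample_rate)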

View File

@@ -7,15 +7,18 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.vits import Vits, VitsArgs
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor

 output_path = os.path.dirname(os.path.abspath(__file__))

-mailabs_path = '/home/julian/workspace/mailabs/**'
+mailabs_path = "/home/julian/workspace/mailabs/**"
 dataset_paths = glob(mailabs_path)
-dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
+dataset_config = [
+    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
+    for path in dataset_paths
+]

 audio_config = BaseAudioConfig(
     sample_rate=16000,
@@ -61,7 +64,7 @@ config = VitsConfig(
     phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
     compute_input_seq_cache=True,
     print_step=25,
-    use_language_weighted_sampler= True,
+    use_language_weighted_sampler=True,
     print_eval=False,
     mixed_precision=False,
     sort_by_audio_len=True,
@@ -69,21 +72,31 @@ config = VitsConfig(
     max_seq_len=160000,
     output_path=output_path,
     datasets=dataset_config,
-    characters= {
+    characters={
         "pad": "_",
         "eos": "&",
         "bos": "*",
         "characters": "'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
         "punctuations": "'(),-.:;¿? ",
         "phonemes": None,
-        "unique": True
+        "unique": True,
     },
     test_sentences=[
-        ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
-        ["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
-        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
-        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
-    ]
+        [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "mary_ann",
+            None,
+            "en_US",
+        ],
+        [
+            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
+            "ezwa",
+            None,
+            "fr_FR",
+        ],
+        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
+        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
+    ],
 )

 # init audio processor

View File

@@ -26,3 +26,4 @@ unidic-lite==1.0.8
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld
+webrtcvad

View File

@@ -31,7 +31,7 @@ dataset_config_pt = BaseDatasetConfig(

 class TestFindUniquePhonemes(unittest.TestCase):
     @staticmethod
     def test_espeak_phonemes():
         # prepare the config
         config = VitsConfig(
             batch_size=2,
             eval_batch_size=2,
@@ -52,9 +52,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)

         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

     @staticmethod
     def test_no_espeak_phonemes():
@@ -79,6 +77,4 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)

         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')

View File

@@ -1,9 +1,11 @@
-from TTS.tts.datasets import load_tts_samples
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.utils.languages import get_language_weighted_sampler
-import torch
 import functools
+
+import torch
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.languages import get_language_weighted_sampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -25,18 +27,19 @@ dataset_config_pt = BaseDatasetConfig(

 # Adding the EN samples twice to create an unbalanced dataset
 train_samples, eval_samples = load_tts_samples(
-    [dataset_config_en, dataset_config_en, dataset_config_pt],
-    eval_split=True
+    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
 )


 def is_balanced(lang_1, lang_2):
-    return 0.85 < lang_1/lang_2 < 1.2
+    return 0.85 < lang_1 / lang_2 < 1.2


 random_sampler = torch.utils.data.RandomSampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1
@@ -47,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1