diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 014ba4e8..7b489fd6 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -242,7 +242,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         speaker_manager = None
 
-
     # setup model
     model = setup_model(c)
diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
index 832ef082..d3143ca3 100644
--- a/TTS/bin/find_unique_phonemes.py
+++ b/TTS/bin/find_unique_phonemes.py
@@ -19,6 +19,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))
 
+
 def main():  # pylint: disable=W0601
     global c
diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py
index a32f0f45..9070f2da 100755
--- a/TTS/bin/remove_silence_using_vad.py
+++ b/TTS/bin/remove_silence_using_vad.py
@@ -1,12 +1,13 @@
-import os
-import glob
-import pathlib
 import argparse
+import glob
 import multiprocessing
+import os
+import pathlib
 
 from tqdm.contrib.concurrent import process_map
 
-from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
+from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
+
 
 def remove_silence(filepath):
     output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
@@ -69,10 +70,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
     )
-    parser.add_argument("-f", "--force",
-                        default=False,
-                        action='store_true',
-                        help='Force the replace of exists files')
+    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force overwriting existing files")
     parser.add_argument(
         "-g",
         "--glob",
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 5330649a..191cba00 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -4,8 +4,8 @@ from TTS.config import load_config, register_config
 from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
diff --git a/TTS/server/server.py b/TTS/server/server.py
index c6d67141..f7bc79c4 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -100,7 +100,7 @@ if args.vocoder_path is not None:
 
 # load models
 synthesizer = Synthesizer(
-    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
+    model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
 )
 
 use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index 8f0a8809..7bd507fb 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -2,11 +2,12 @@ import numpy as np
 import torch
 from torch import nn
 
+from TTS.utils.audio import TorchSTFT
+from TTS.utils.io import load_fsspec
+
 # import torchaudio
 
-from TTS.utils.audio import TorchSTFT
-from TTS.utils.io import load_fsspec
 
 
 class PreEmphasis(torch.nn.Module):
@@ -126,16 +127,16 @@ class ResNetSpeakerEncoder(nn.Module):
                 n_mels=audio_config["num_mels"],
                 power=2.0,
                 use_mel=True,
-                mel_norm=None
+                mel_norm=None,
             ),
-            '''torchaudio.transforms.MelSpectrogram(
+            """torchaudio.transforms.MelSpectrogram(
                 sample_rate=audio_config["sample_rate"],
                 n_fft=audio_config["fft_size"],
                 win_length=audio_config["win_length"],
                 hop_length=audio_config["hop_length"],
                 window_fn=torch.hamming_window,
                 n_mels=audio_config["num_mels"],
-            ),'''
+            ),""",
         )
         else:
             self.torch_spec = None
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 000393ea..843cea58 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -531,7 +531,7 @@ class TTSDataset(Dataset):
                 "waveform": wav_padded,
                 "raw_text": batch["raw_text"],
                 "pitch": pitch,
-                "language_ids": language_ids
+                "language_ids": language_ids,
             }
 
         raise TypeError(
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 9c219998..7de45041 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -588,7 +588,7 @@ class VitsGeneratorLoss(nn.Module):
 
     @staticmethod
     def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
-        l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
         return l
 
     def forward(
diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py
index acd89110..4cc8b658 100644
--- a/TTS/tts/models/__init__.py
+++ b/TTS/tts/models/__init__.py
@@ -2,11 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module
 
 
-def setup_model(
-    config,
-    speaker_manager: "SpeakerManager" = None,
-    language_manager: "LanguageManager" = None
-    ):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None: @@ -35,7 +31,7 @@ def setup_model( config.model_params.num_chars = num_chars if "model_args" in config: config.model_args.num_chars = num_chars - if config.model.lower() in ["vits"]: # If model supports multiple languages + if config.model.lower() in ["vits"]: # If model supports multiple languages model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager) else: model = MyModel(config, speaker_manager=speaker_manager) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 1f92bfc7..e52cd765 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,8 +12,8 @@ from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -150,7 +150,13 @@ class BaseTTS(BaseModel): if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: language_id = self.language_manager.language_id_mapping[language_name] - return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id} + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } def format_batch(self, batch: Dict) -> Dict: """Generic batch formatting for `TTSDataset`. 
@@ -337,14 +343,16 @@ class BaseTTS(BaseModel): if config.compute_f0: dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) - - # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None # Weighted samplers - assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler" - assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler" + assert not ( + num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) + ), "language_weighted_sampler is not supported with DistributedSampler" + assert not ( + num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) + ), "speaker_weighted_sampler is not supported with DistributedSampler" if sampler is None: if getattr(config, "use_language_weighted_sampler", False): @@ -354,7 +362,6 @@ class BaseTTS(BaseModel): print(" > Using Language weighted sampler") sampler = get_speaker_weighted_sampler(dataset.items) - loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index ca110eb0..5b4725b3 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -4,6 +4,7 @@ from itertools import chain from typing import Dict, List, Tuple import torch + # import torchaudio from coqpit import Coqpit from torch import nn @@ -420,8 +421,9 @@ class Vits(BaseTTS): ): # TODO: change this with torchaudio Resample raise RuntimeError( - ' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!' - .format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]) + " [!] 
To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format( + self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"] + ) ) # pylint: disable=W0101,W0105 """ self.audio_transform = torchaudio.transforms.Resample( @@ -675,7 +677,6 @@ class Vits(BaseTTS): ) return outputs - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}): """ Shapes: diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index 10c9ec7e..d01196c4 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -88,7 +88,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method spec_gain=1.0, power=None, use_htk=False, - mel_norm="slaney" + mel_norm="slaney", ): super().__init__() self.n_fft = n_fft @@ -155,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method def _build_mel_basis(self): mel_basis = librosa.filters.mel( - self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm + self.sample_rate, + self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, ) self.mel_basis = torch.from_numpy(mel_basis).float() diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index e6df6561..d64c0936 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -7,8 +7,8 @@ import torch from TTS.config import load_config from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -200,12 +200,7 @@ class Synthesizer(object): self.ap.save_wav(wav, path, self.output_sample_rate) def tts( - self, - text: str, - speaker_idx: str = "", - language_idx: str = "", - speaker_wav=None, - style_wav=None + self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None ) -> List[int]: """🐸 TTS magic. Run all the models and generate speech. 
@@ -254,7 +249,9 @@ class Synthesizer(object):
         # handle multi-lingual
         language_id = None
-        if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
+        if self.tts_languages_file or (
+            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+        ):
             if language_idx and isinstance(language_idx, str):
                 language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py
index 33548087..923544d0 100644
--- a/TTS/utils/vad.py
+++ b/TTS/utils/vad.py
@@ -1,8 +1,9 @@
 # This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import wave
-import webrtcvad
-import contextlib
 import collections
+import contextlib
+import wave
+
+import webrtcvad
 
 
 def read_wave(path):
@@ -37,7 +38,7 @@ class Frame(object):
     """Represents a "frame" of audio data."""
 
     def __init__(self, _bytes, timestamp, duration):
-        self.bytes =_bytes
+        self.bytes = _bytes
         self.timestamp = timestamp
         self.duration = duration
 
@@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, fram
     if voiced_frames:
         yield b"".join([f.bytes for f in voiced_frames])
 
+
 def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
 
     vad = webrtcvad.Vad(int(aggressiveness))
diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py
index 6beaef38..be4747df 100644
--- a/recipes/multilingual/vits_tts/train_vits_tts.py
+++ b/recipes/multilingual/vits_tts/train_vits_tts.py
@@ -7,15 +7,18 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.vits import Vits, VitsArgs
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 
-mailabs_path = '/home/julian/workspace/mailabs/**'
+mailabs_path = "/home/julian/workspace/mailabs/**"
 dataset_paths = glob(mailabs_path)
-dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
+dataset_config = [
+    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
+    for path in dataset_paths
+]
 
 audio_config = BaseAudioConfig(
     sample_rate=16000,
@@ -61,7 +64,7 @@ config = VitsConfig(
     phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
     compute_input_seq_cache=True,
     print_step=25,
-    use_language_weighted_sampler= True,
+    use_language_weighted_sampler=True,
     print_eval=False,
     mixed_precision=False,
     sort_by_audio_len=True,
@@ -69,21 +72,31 @@ config = VitsConfig(
     max_seq_len=160000,
     output_path=output_path,
     datasets=dataset_config,
-    characters= {
+    characters={
         "pad": "_",
         "eos": "&",
         "bos": "*",
         "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
         "punctuations": "!¡'(),-.:;¿? ",
         "phonemes": None,
-        "unique": True
+        "unique": True,
     },
     test_sentences=[
-        ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
-        ["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
-        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
-        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
-    ]
+        [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "mary_ann",
+            None,
+            "en_US",
+        ],
+        [
+            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
+            "ezwa",
+            None,
+            "fr_FR",
+        ],
+        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
+        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
+    ],
 )
 
 # init audio processor
diff --git a/requirements.txt b/requirements.txt
index 3ec33ceb..453c3ec4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,3 +26,4 @@ unidic-lite==1.0.8
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld
+webrtcvad
diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py
index 33fad9ba..fa0abe4b 100644
--- a/tests/aux_tests/test_find_unique_phonemes.py
+++ b/tests/aux_tests/test_find_unique_phonemes.py
@@ -31,7 +31,7 @@ dataset_config_pt = BaseDatasetConfig(
 class TestFindUniquePhonemes(unittest.TestCase):
     @staticmethod
     def test_espeak_phonemes():
-        # prepare the config 
+        # prepare the config
         config = VitsConfig(
             batch_size=2,
             eval_batch_size=2,
@@ -52,9 +52,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)
 
         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
 
     @staticmethod
     def test_no_espeak_phonemes():
@@ -79,6 +77,4 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)
 
         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index 5e4e4151..3d8d6c75 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -1,9 +1,11 @@
-from TTS.tts.datasets import load_tts_samples
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.utils.languages import get_language_weighted_sampler
-import torch
 import functools
+
+import torch
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.languages import get_language_weighted_sampler
 
 # Fixing random state to avoid random fails
 torch.manual_seed(0)
 
@@ -25,18 +27,19 @@ dataset_config_pt = BaseDatasetConfig(
 
 # Adding the EN samples twice to create an unbalanced dataset
 train_samples, eval_samples = load_tts_samples(
-    [dataset_config_en, dataset_config_en, dataset_config_pt],
-    eval_split=True
+    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
 )
 
+
 def is_balanced(lang_1, lang_2):
-    return 0.85 < lang_1/lang_2 < 1.2
+    return 0.85 < lang_1 / lang_2 < 1.2
+
 
 random_sampler = torch.utils.data.RandomSampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1
@@ -47,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1