From d309f50e53aaa4fa6fc540f98615f1963d61447f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 25 Mar 2023 18:33:23 +0100 Subject: [PATCH] Implement FreeVC (#2451) * Update .gitignore * Draft FreeVC implementation * Tests and relevant updates * Update API tests * Add missings * Update requirements * :( * Lazy handle for vc * Update docs for voice conversion * Make style --- .gitignore | 3 +- TTS/.models.json | 13 + TTS/api.py | 89 +- TTS/bin/synthesize.py | 80 +- TTS/config/__init__.py | 2 +- TTS/tts/models/base_tts.py | 2 + TTS/utils/generic_utils.py | 1 + TTS/utils/manage.py | 8 + TTS/utils/synthesizer.py | 54 +- TTS/vc/configs/__init__.py | 0 TTS/vc/configs/freevc_config.py | 5 + TTS/vc/configs/shared_configs.py | 155 ++++ TTS/vc/models/__init__.py | 17 + TTS/vc/models/base_vc.py | 429 +++++++++ TTS/vc/models/freevc.py | 833 ++++++++++++++++++ TTS/vc/modules/__init__.py | 0 TTS/vc/modules/freevc/__init__.py | 0 TTS/vc/modules/freevc/commons.py | 170 ++++ TTS/vc/modules/freevc/mel_processing.py | 125 +++ TTS/vc/modules/freevc/modules.py | 391 ++++++++ .../freevc/speaker_encoder/__init__.py | 0 .../modules/freevc/speaker_encoder/audio.py | 65 ++ .../modules/freevc/speaker_encoder/hparams.py | 31 + .../freevc/speaker_encoder/speaker_encoder.py | 175 ++++ TTS/vc/modules/freevc/wavlm/__init__.py | 35 + TTS/vc/modules/freevc/wavlm/config.json | 99 +++ TTS/vc/modules/freevc/wavlm/modules.py | 768 ++++++++++++++++ TTS/vc/modules/freevc/wavlm/wavlm.py | 719 +++++++++++++++ TTS/vocoder/models/base_vocoder.py | 2 + docs/source/inference.md | 36 +- requirements.txt | 1 + tests/inference_tests/test_python_api.py | 4 +- tests/vc_tests/__init__.py | 0 tests/vc_tests/test_freevc.py | 135 +++ tests/zoo_tests/test_models.py | 7 + 35 files changed, 4416 insertions(+), 38 deletions(-) create mode 100644 TTS/vc/configs/__init__.py create mode 100644 TTS/vc/configs/freevc_config.py create mode 100644 TTS/vc/configs/shared_configs.py create mode 100644 TTS/vc/models/__init__.py create mode 100644 TTS/vc/models/base_vc.py create mode 100644 TTS/vc/models/freevc.py create mode 100644 TTS/vc/modules/__init__.py create mode 100644 TTS/vc/modules/freevc/__init__.py create mode 100644 TTS/vc/modules/freevc/commons.py create mode 100644 TTS/vc/modules/freevc/mel_processing.py create mode 100644 TTS/vc/modules/freevc/modules.py create mode 100644 TTS/vc/modules/freevc/speaker_encoder/__init__.py create mode 100644 TTS/vc/modules/freevc/speaker_encoder/audio.py create mode 100644 TTS/vc/modules/freevc/speaker_encoder/hparams.py create mode 100644 TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py create mode 100644 TTS/vc/modules/freevc/wavlm/__init__.py create mode 100644 TTS/vc/modules/freevc/wavlm/config.json create mode 100644 TTS/vc/modules/freevc/wavlm/modules.py create mode 100644 TTS/vc/modules/freevc/wavlm/wavlm.py create mode 100644 tests/vc_tests/__init__.py create mode 100644 tests/vc_tests/test_freevc.py diff --git a/.gitignore b/.gitignore index 7ce70e40..563040e8 100644 --- a/.gitignore +++ b/.gitignore @@ -137,7 +137,7 @@ VCTK-Corpus-removed-silence/* # ignore training logs trainer_*_log.txt -# files used internally fro dev, test etc. +# files used internally for dev, test etc. 
tests/outputs/* tests/train_outputs/* TODO.txt @@ -168,3 +168,4 @@ internal/* wandb depot/* coqui_recipes/* +local_scripts/* diff --git a/TTS/.models.json b/TTS/.models.json index 7050a9bc..02b95cf2 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -802,5 +802,18 @@ } } } + }, + "voice_conversion_models":{ + "multilingual":{ + "vctk":{ + "freevc24":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip", + "description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC", + "author": "Jing-Yi Li @OlaWod", + "license": "MIT", + "commit": null + } + } + } } } diff --git a/TTS/api.py b/TTS/api.py index 0e694263..008fd082 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,5 +1,7 @@ +import tempfile from pathlib import Path +from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer @@ -49,11 +51,14 @@ class TTS: gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.synthesizer = None + self.voice_converter = None + if model_name: - self.load_model_by_name(model_name, gpu) + self.load_tts_model_by_name(model_name, gpu) if model_path: - self.load_model_by_path( + self.load_tts_model_by_path( model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu ) @@ -96,12 +101,22 @@ class TTS: def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) - if model_item["default_vocoder"] is None: + if model_item.get("default_vocoder") is None: return model_path, config_path, None, None vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) return model_path, config_path, vocoder_path, vocoder_config_path - def load_model_by_name(self, model_name: str, gpu: bool = False): + def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the voice conversion models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + model_path, config_path, _, _ = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + + def load_tts_model_by_name(self, model_name: str, gpu: bool = False): """Load one of 🐸TTS models by name. Args: @@ -127,7 +142,7 @@ class TTS: use_cuda=gpu, ) - def load_model_by_path( + def load_tts_model_by_path( self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False ): """Load a model from a path. @@ -219,3 +234,67 @@ class TTS: """ wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav) self.synthesizer.save_wav(wav=wav, path=file_path) + + def voice_conversion( + self, + source_wav: str, + target_wav: str, + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args: + source_wav (str): + Path to the source wav file. + target_wav (str): + Path to the target wav file. 
+ """ + wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) + return wav + + def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None): + """Convert text to speech with voice conversion. + + It combines TTS with voice conversion to approximate voice cloning. + + - Convert text to speech with tts. + - Convert the output wav to target speaker with voice conversion. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + """ + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + # Lazy code... save it to a temp file to resample it while reading it for VC + self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name) + if self.voice_converter is None: + self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24") + wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav) + return wav + + def tts_with_vc_to_file( + self, text: str, language: str = None, speaker_wav: str = None, file_path: str = "output.wav" + ): + """Convert text to speech with voice conversion and save to file. + + Check `tts_with_vc` for more details. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + file_path (str, optional): + Output file path. Defaults to "output.wav". + """ + wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index bbcb9c95..2877ea2b 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -100,6 +100,12 @@ If you don't specify any models, then it uses LJSpeech based English model. ``` $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id> ``` + +### Voice Conversion Models + + ``` + $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav> + ``` """ # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep # documentation in sync more easily. @@ -245,6 +251,20 @@ If you don't specify any models, then it uses LJSpeech based English model. default=True, ) + # voice conversion args + parser.add_argument( + "--source_wav", + type=str, + default=None, + help="Original audio file to convert into the voice of the target_wav", + ) + parser.add_argument( + "--target_wav", + type=str, + default=None, + help="Target audio file providing the voice to convert the source_wav into", + ) + args = parser.parse_args() # print the description if either text or list_models is not set @@ -256,6 +276,8 @@ If you don't specify any models, then it uses LJSpeech based English model. 
args.reference_wav, args.model_info_by_idx, args.model_info_by_name, + args.source_wav, + args.target_wav, ] if not any(check_args): parser.parse_args(["-h"]) @@ -264,21 +286,23 @@ If you don't specify any models, then it uses LJSpeech based English model. path = Path(__file__).parent / "../.models.json" manager = ModelManager(path, progress_bar=args.progress_bar) - model_path = None - config_path = None + tts_path = None + tts_config_path = None speakers_file_path = None language_ids_file_path = None vocoder_path = None vocoder_config_path = None encoder_path = None encoder_config_path = None + vc_path = None + vc_config_path = None # CASE1 #list : list pre-trained TTS models if args.list_models: manager.list_models() sys.exit() - # CASE2 #info : model info of pre-trained TTS models + # CASE2 #info : model info for pre-trained TTS models if args.model_info_by_idx: model_query = args.model_info_by_idx manager.model_info_by_idx(model_query) @@ -292,15 +316,27 @@ If you don't specify any models, then it uses LJSpeech based English model. # CASE3: load pre-trained model paths if args.model_name is not None and not args.model_path: model_path, config_path, model_item = manager.download_model(args.model_name) - args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + # tts model + if model_item["model_type"] == "tts_models": + tts_path = model_path + tts_config_path = config_path + if "default_vocoder" in model_item: + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + + # voice conversion model + if model_item["model_type"] == "voice_conversion_models": + vc_path = model_path + vc_config_path = config_path + + # load vocoder if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) # CASE4: set custom model paths if args.model_path is not None: - model_path = args.model_path - config_path = args.config_path + tts_path = args.model_path + tts_config_path = args.config_path speakers_file_path = args.speakers_file_path language_ids_file_path = args.language_ids_file_path @@ -314,14 +350,16 @@ If you don't specify any models, then it uses LJSpeech based English model. # load models synthesizer = Synthesizer( - model_path, - config_path, + tts_path, + tts_config_path, speakers_file_path, language_ids_file_path, vocoder_path, vocoder_config_path, encoder_path, encoder_config_path, + vc_path, + vc_config_path, args.use_cuda, ) @@ -354,16 +392,22 @@ If you don't specify any models, then it uses LJSpeech based English model. 
print(" > Text: {}".format(args.text)) # kick it - wav = synthesizer.tts( - args.text, - args.speaker_idx, - args.language_idx, - args.speaker_wav, - reference_wav=args.reference_wav, - style_wav=args.capacitron_style_wav, - style_text=args.capacitron_style_text, - reference_speaker_name=args.reference_speaker_idx, - ) + if tts_path is not None: + wav = synthesizer.tts( + args.text, + args.speaker_idx, + args.language_idx, + args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + elif vc_path is not None: + wav = synthesizer.voice_conversion( + source_wav=args.source_wav, + target_wav=args.target_wav, + ) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 067c32d9..49ac8f4a 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit: """ config_class = None config_name = model_name + "_config" - paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"] + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"] for path in paths: try: config_class = find_module(path, config_name) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 37a09354..bda0fc9e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -27,6 +27,8 @@ class BaseTTS(BaseTrainerModel): It defines common `tts` specific functions on top of `Model` implementation. """ + MODEL_TYPE = "tts" + def __init__( self, config: Coqpit, diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index d5312f49..b1c16468 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -85,6 +85,7 @@ def to_camel(text): text = text.capitalize() text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) text = text.replace("Tts", "TTS") + text = text.replace("vc", "VC") return text diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index ef4c11f5..8419429d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -185,6 +185,13 @@ class ModelManager(object): """ return self._list_for_model_type("vocoder_models") + def list_vc_models(self): + """Print all the voice conversion models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("voice_conversion_models") + def list_langs(self): """Print all the available languages""" print(" Name format: type/language") @@ -234,6 +241,7 @@ class ModelManager(object): model_type, lang, dataset, model = model_name.split("/") model_full_name = f"{model_type}--{lang}--{dataset}--{model}" model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index c197b1f5..6c368a2b 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -12,6 +12,8 @@ from TTS.tts.models import setup_model as setup_tts_model # pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.models import setup_model as setup_vc_model from 
TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input @@ -19,14 +21,16 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input class Synthesizer(object): def __init__( self, - tts_checkpoint: str, - tts_config_path: str, + tts_checkpoint: str = "", + tts_config_path: str = "", tts_speakers_file: str = "", tts_languages_file: str = "", vocoder_checkpoint: str = "", vocoder_config: str = "", encoder_checkpoint: str = "", encoder_config: str = "", + vc_checkpoint: str = "", + vc_config: str = "", use_cuda: bool = False, ) -> None: """General 🐸 TTS interface for inference. It takes a tts and a vocoder @@ -41,12 +45,14 @@ class Synthesizer(object): TODO: set the segmenter based on the source language Args: - tts_checkpoint (str): path to the tts model file. - tts_config_path (str): path to the tts config file. + tts_checkpoint (str, optional): path to the tts model file. + tts_config_path (str, optional): path to the tts config file. vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. vocoder_config (str, optional): path to the vocoder config file. Defaults to None. encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`, + vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`, + vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ self.tts_checkpoint = tts_checkpoint @@ -57,10 +63,13 @@ class Synthesizer(object): self.vocoder_config = vocoder_config self.encoder_checkpoint = encoder_checkpoint self.encoder_config = encoder_config + self.vc_checkpoint = vc_checkpoint + self.vc_config = vc_config self.use_cuda = use_cuda self.tts_model = None self.vocoder_model = None + self.vc_model = None self.speaker_manager = None self.tts_speakers = {} self.language_manager = None @@ -72,12 +81,19 @@ class Synthesizer(object): if self.use_cuda: assert torch.cuda.is_available(), "CUDA is not availabe on this machine." - self._load_tts(tts_checkpoint, tts_config_path, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] + + if tts_checkpoint: + self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + if vocoder_checkpoint: self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) self.output_sample_rate = self.vocoder_config.audio["sample_rate"] + if vc_checkpoint: + self._load_vc(vc_checkpoint, vc_config, use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + @staticmethod def _get_segmenter(lang: str): """get the sentence segmenter for the given language. @@ -90,6 +106,26 @@ class Synthesizer(object): """ return pysbd.Segmenter(language=lang, clean=True) + def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None: + """Load the voice conversion model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + vc_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. 
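As an illustration of the extended `Synthesizer` constructor, voice conversion can also be driven directly at this level. This is a sketch only; `vc_model.pth` and `vc_config.json` stand in for a downloaded FreeVC checkpoint and its config.

```python
from TTS.utils.synthesizer import Synthesizer

# Placeholder paths to a FreeVC checkpoint/config; source.wav and target.wav are input files.
synthesizer = Synthesizer(vc_checkpoint="vc_model.pth", vc_config="vc_config.json", use_cuda=False)
wav = synthesizer.voice_conversion(source_wav="source.wav", target_wav="target.wav")
synthesizer.save_wav(wav, "converted.wav")
```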
+ """ + # pylint: disable=global-statement + self.vc_config = load_config(vc_config_path) + self.vc_model = setup_vc_model(config=self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) + if use_cuda: + self.vc_model.cuda() + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: """Load the TTS model. @@ -168,7 +204,11 @@ class Synthesizer(object): path (str): output path to save the waveform. """ wav = np.array(wav) - self.tts_model.ap.save_wav(wav, path, self.output_sample_rate) + save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate) + + def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]: + output_wav = self.vc_model.voice_conversion(source_wav, target_wav) + return output_wav def tts( self, diff --git a/TTS/vc/configs/__init__.py b/TTS/vc/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py new file mode 100644 index 00000000..890a2693 --- /dev/null +++ b/TTS/vc/configs/freevc_config.py @@ -0,0 +1,5 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.vc.configs.shared_configs import BaseVCConfig +from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig diff --git a/TTS/vc/configs/shared_configs.py b/TTS/vc/configs/shared_configs.py new file mode 100644 index 00000000..74164a74 --- /dev/null +++ b/TTS/vc/configs/shared_configs.py @@ -0,0 +1,155 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseVCConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. 
+ + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. 
+ """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py new file mode 100644 index 00000000..5a09b4e5 --- /dev/null +++ b/TTS/vc/models/__init__.py @@ -0,0 +1,17 @@ +import importlib +import re +from typing import Dict, List, Union + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. + if "model" in config and config["model"].lower() == "freevc": + MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC + model = MyModel.init_from_config(config, samples) + return model diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py new file mode 100644 index 00000000..19f2761b --- /dev/null +++ b/TTS/vc/models/base_vc.py @@ -0,0 +1,429 @@ +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +# pylint: skip-file + + +class BaseVC(BaseTrainerModel): + """Base `vc` class. Every new `vc` model must inherit this. + + It defines common `vc` specific functions on top of `Model` implementation. 
+ """ + + MODEL_TYPE = "vc" + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__() + self.config = config + self.ap = ap + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. 
+ """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `VCDataset`. + + You must override this if you use a custom dataset. 
+ + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * 
alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=None, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. 
+ sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `vc` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. 
Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.pth` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py new file mode 100644 index 00000000..4aa26724 --- /dev/null +++ b/TTS/vc/models/freevc.py @@ -0,0 +1,833 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import librosa +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +import TTS.vc.modules.freevc.commons as commons +import TTS.vc.modules.freevc.modules as modules +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.io import load_fsspec, save_checkpoint +from TTS.vc.configs.shared_configs import BaseVCConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.modules.freevc.commons import get_padding, init_weights +from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch +from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx +from TTS.vc.modules.freevc.wavlm import get_wavlm + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers 
+ self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + 
n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SpeakerEncoder(torch.nn.Module): + def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:, -partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) + mels = list(mel[:, s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) + # embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + + +@dataclass +class FreeVCAudioConfig(Coqpit): + """Audio configuration + + Args: + max_wav_value (float): + The maximum value of the waveform. 
+ + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + filter_length (int): + The length of the filter. + + hop_length (int): + The hop length. + + win_length (int): + The window length. + + n_mel_channels (int): + The number of mel channels. + + mel_fmin (float): + The minimum frequency of the mel filterbank. + + mel_fmax (Optional[float]): + The maximum frequency of the mel filterbank. + """ + + max_wav_value: float = field(default=32768.0) + input_sample_rate: int = field(default=16000) + output_sample_rate: int = field(default=24000) + filter_length: int = field(default=1280) + hop_length: int = field(default=320) + win_length: int = field(default=1280) + n_mel_channels: int = field(default=80) + mel_fmin: float = field(default=0.0) + mel_fmax: Optional[float] = field(default=None) + + +@dataclass +class FreeVCArgs(Coqpit): + """FreeVC model arguments + + Args: + spec_channels (int): + The number of channels in the spectrogram. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. + + ssl_dim (int): + The dimension of the self-supervised learning embedding. + + use_spk (bool): + Whether to use external speaker encoder. 
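As a small sketch, these coqpit dataclasses can be instantiated directly with the defaults listed just below (FreeVC consumes 16 kHz audio and generates 24 kHz audio by default):

```python
from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig

audio = FreeVCAudioConfig()       # 16 kHz in, 24 kHz out, 1280-sample window, 320-sample hop
args = FreeVCArgs(use_spk=False)  # compute speaker embeddings with the built-in encoder
print(audio.input_sample_rate, audio.output_sample_rate, args.ssl_dim)  # 16000 24000 1024
```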
+ """ + + spec_channels: int = field(default=641) + inter_channels: int = field(default=192) + hidden_channels: int = field(default=192) + filter_channels: int = field(default=768) + n_heads: int = field(default=2) + n_layers: int = field(default=6) + kernel_size: int = field(default=3) + p_dropout: float = field(default=0.1) + resblock: str = field(default="1") + resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2]) + upsample_initial_channel: int = field(default=512) + upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + n_layers_q: int = field(default=3) + use_spectral_norm: bool = field(default=False) + gin_channels: int = field(default=256) + ssl_dim: int = field(default=1024) + use_spk: bool = field(default=False) + num_spks: int = field(default=0) + segment_size: int = field(default=8960) + + +class FreeVC(BaseVC): + """ + + Papaer:: + https://arxiv.org/abs/2210.15418# + + Paper Abstract:: + Voice conversion (VC) can be achieved by first extracting source content information and target speaker + information, and then reconstructing waveform with these information. However, current approaches normally + either extract dirty content information with speaker information leaked in, or demand a large amount of + annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the + mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for + high-quality waveform reconstruction, and propose strategies for clean content information extraction without + text annotation. We disentangle content information by imposing an information bottleneck to WavLM features, + and propose the spectrogram-resize based data augmentation to improve the purity of extracted content + information. Experimental results show that the proposed method outperforms the latest VC models trained with + annotated data and has greater robustness. 
+ + Original Code:: + https://github.com/OlaWod/FreeVC + + Examples: + >>> from TTS.vc.configs.freevc_config import FreeVCConfig + >>> from TTS.vc.models.freevc import FreeVC + >>> config = FreeVCConfig() + >>> model = FreeVC(config) + """ + + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.spec_channels = self.args.spec_channels + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.segment_size = self.args.segment_size + self.gin_channels = self.args.gin_channels + self.ssl_dim = self.args.ssl_dim + self.use_spk = self.args.use_spk + + self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16) + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = Encoder( + self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels + ) + self.flow = ResidualCouplingBlock( + self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels + ) + if not self.use_spk: + self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels) + else: + self.load_pretrained_speaker_encoder() + + self.wavlm = get_wavlm() + + @property + def device(self): + return next(self.parameters()).device + + def load_pretrained_speaker_encoder(self): + """Load pretrained speaker encoder model as mentioned in the paper.""" + print(" > Loading pretrained speaker encoder model ...") + self.enc_spk_ex = SpeakerEncoderEx( + "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" + ) + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.num_spks = self.args.num_spks + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_spks + + def forward( + self, + c: torch.Tensor, + spec: torch.Tensor, + g: Optional[torch.Tensor] = None, + mel: Optional[torch.Tensor] = None, + c_lengths: Optional[torch.Tensor] = None, + spec_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ]: + """ + Forward pass of the model. + + Args: + c: WavLM features. Shape: (batch_size, c_seq_len). 
+ spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + g: The speaker embedding. Shape: (batch_size, spk_emb_dim). + mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths: The lengths of the WavLM features. Shape: (batch_size,). + spec_lengths: The lengths of the spectrogram. Shape: (batch_size,). + + Returns: + o: The output spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + ids_slice: The slice indices. Shape: (batch_size, num_slices). + spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables. + """ + + # If c_lengths is None, set it to the length of the last dimension of c + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + # If spec_lengths is None, set it to the length of the last dimension of spec + if spec_lengths is None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + # If use_spk is False, compute g from mel using enc_spk + g = None + if not self.use_spk: + g = self.enc_spk(mel).unsqueeze(-1) + + # Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g) + + # Compute z_p using flow + z_p = self.flow(z, spec_mask, g=g) + + # Randomly slice z and compute o using dec + z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + @torch.no_grad() + def inference(self, c, g=None, mel=None, c_lengths=None): + """ + Inference pass of the model + + Args: + c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + + Returns: + torch.Tensor: Output tensor. + """ + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel) + g = g.unsqueeze(-1) + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + return o + + def extract_wavlm_features(self, y): + """Extract WavLM features from an audio tensor. + + Args: + y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len). + """ + + with torch.no_grad(): + c = self.wavlm.extract_features(y)[0] + c = c.transpose(1, 2) + return c + + def load_audio(self, wav): + """Read and format the input audio.""" + if isinstance(wav, str): + wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate) + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav).to(self.device) + if isinstance(wav, torch.Tensor): + wav = wav.to(self.device) + if isinstance(wav, list): + wav = torch.from_numpy(np.array(wav)).to(self.device) + return wav.float() + + @torch.inference_mode() + def voice_conversion(self, src, tgt): + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + torch.Tensor: Output tensor. 
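+
+        Example (illustrative sketch; assumes `model` is a loaded FreeVC instance and the paths are placeholders):
+            >>> converted_wav = model.voice_conversion("path/to/source.wav", "path/to/target.wav")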
+ """ + + wav_tgt = self.load_audio(tgt).cpu().numpy() + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + + if self.config.model_args.use_spk: + g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device) + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device) + mel_tgt = mel_spectrogram_torch( + wav_tgt, + self.config.audio.filter_length, + self.config.audio.n_mel_channels, + self.config.audio.input_sample_rate, + self.config.audio.hop_length, + self.config.audio.win_length, + self.config.audio.mel_fmin, + self.config.audio.mel_fmax, + ) + # src + wav_src = self.load_audio(src) + c = self.extract_wavlm_features(wav_src[None, :]) + + if self.config.model_args.use_spk: + audio = self.inference(c, g=g_tgt) + else: + audio = self.inference(c, mel=mel_tgt.transpose(1, 2)) + audio = audio[0][0].data.cpu().float().numpy() + return audio + + def eval_step(): + ... + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + model = FreeVC(config) + return model + + def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def train_step(): + ... + + +@dataclass +class FreeVCConfig(BaseVCConfig): + """Defines parameters for FreeVC End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (FreeVCArgs): + Model architecture arguments. Defaults to `FreeVCArgs()`. + + audio (FreeVCAudioConfig): + Audio processing configuration. Defaults to `FreeVCAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. 
+ + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.freevc_config import FreeVCConfig + >>> config = FreeVCConfig() + """ + + model: str = "freevc" + # model specific params + model_args: FreeVCArgs = FreeVCArgs() + audio: FreeVCAudioConfig = FreeVCAudioConfig() + + # optimizer + # TODO with training support + + # loss params + # TODO with training support + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + speakers_file: str = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/vc/modules/__init__.py b/TTS/vc/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/vc/modules/freevc/__init__.py b/TTS/vc/modules/freevc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py new file mode 100644 index 00000000..5684a88e --- /dev/null +++ b/TTS/vc/modules/freevc/commons.py @@ -0,0 +1,170 @@ +import math + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + 
pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, 
dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py new file mode 100644 index 00000000..2dcbf214 --- /dev/null +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -0,0 +1,125 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = 
spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py new file mode 100644 index 00000000..0503a13c --- /dev/null +++ b/TTS/vc/modules/freevc/modules.py @@ -0,0 +1,391 @@ +import copy +import math + +import numpy as np +import scipy +import torch +from torch import nn +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +import TTS.vc.modules.freevc.commons as commons +from TTS.vc.modules.freevc.commons import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
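+        # Conv1d + LayerNorm blocks with a shared ReLU/dropout in between; the final 1x1 projection starts from zero weights.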
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = 
self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + 
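+# Flip reverses the channel order of its input; it is a volume-preserving flow step (zero log-determinant).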
+class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/TTS/vc/modules/freevc/speaker_encoder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/modules/freevc/speaker_encoder/audio.py new file mode 100644 index 00000000..52f6fd08 --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/audio.py @@ -0,0 +1,65 @@ +import struct +from pathlib import Path +from typing import Optional, Union + +# import webrtcvad +import librosa +import numpy as np +from scipy.ndimage.morphology import binary_dilation + +from TTS.vc.modules.freevc.speaker_encoder.hparams import * + +int16_max = (2**15) - 1 + + +def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. 
If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. + """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(fpath_or_wav, sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + y=wav, + sr=sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels, + ) + return frames.astype(np.float32).T + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/TTS/vc/modules/freevc/speaker_encoder/hparams.py new file mode 100644 index 00000000..2c536ae1 --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/hparams.py @@ -0,0 +1,31 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py new file mode 100644 index 00000000..2e21a14f --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -0,0 +1,175 @@ +from pathlib import Path +from time import perf_counter as timer +from typing import List, Union + +import numpy as np +import torch +from torch import nn + +from TTS.utils.io import load_fsspec +from TTS.vc.modules.freevc.speaker_encoder import audio +from TTS.vc.modules.freevc.speaker_encoder.hparams import * + + +class SpeakerEncoder(nn.Module): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): + """ + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). 
+        If None, defaults to cuda if it is available on your machine, otherwise the model will
+        run on cpu. Outputs are always returned on the cpu, as numpy arrays.
+        """
+        super().__init__()
+
+        # Define the network
+        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+        self.relu = nn.ReLU()
+
+        # Get the target device
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        self.device = device
+
+        # Load the pretrained model's weights
+        # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
+        # if not weights_fpath.exists():
+        #     raise Exception("Couldn't find the voice encoder pretrained model at %s." %
+        #                     weights_fpath)
+
+        start = timer()
+        checkpoint = load_fsspec(weights_fpath, map_location="cpu")
+
+        self.load_state_dict(checkpoint["model_state"], strict=False)
+        self.to(device)
+
+        if verbose:
+            print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
+
+    def forward(self, mels: torch.FloatTensor):
+        """
+        Computes the embeddings of a batch of utterance spectrograms.
+        :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
+        (batch_size, n_frames, n_channels)
+        :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size).
+        Embeddings are positive and L2-normed, thus they lie in the range [0, 1].
+        """
+        # Pass the input through the LSTM layers and retrieve the final hidden state of the last
+        # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
+        _, (hidden, _) = self.lstm(mels)
+        embeds_raw = self.relu(self.linear(hidden[-1]))
+        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+    @staticmethod
+    def compute_partial_slices(n_samples: int, rate, min_coverage):
+        """
+        Computes where to split an utterance waveform and its corresponding mel spectrogram to
+        obtain partial utterances of `partials_n_frames` frames each. Both the waveform and the
+        mel spectrogram slices are returned, so as to make each partial utterance waveform
+        correspond to its spectrogram.
+
+        The returned ranges may be indexing further than the length of the waveform. It is
+        recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
+
+        :param n_samples: the number of samples in the waveform
+        :param rate: how many partial utterances should occur per second. Partial utterances must
+        cover the span of the entire utterance, thus the rate should not be lower than the inverse
+        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
+        the minimum rate is thus 0.625.
+        :param min_coverage: when reaching the last partial utterance, it may or may not have
+        enough frames. If at least `min_coverage` of `partials_n_frames` frames are present,
+        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
+        it will be discarded. If there aren't enough frames for one partial utterance,
+        this parameter is ignored so that the function always returns at least one slice.
+        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+        respectively the waveform and the mel spectrogram with these slices to obtain the partial
+        utterances.
+        """
+        assert 0 < min_coverage <= 1
+
+        # Compute how many frames separate two partial utterances
+        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+        assert 0 < frame_step, "The rate is too high"
+        assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % (
+            sampling_rate / (samples_per_frame * partials_n_frames)
+        )
+
+        # Compute the slices
+        wav_slices, mel_slices = [], []
+        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
+        for i in range(0, steps, frame_step):
+            mel_range = np.array([i, i + partials_n_frames])
+            wav_range = mel_range * samples_per_frame
+            mel_slices.append(slice(*mel_range))
+            wav_slices.append(slice(*wav_range))
+
+        # Evaluate whether extra padding is warranted or not
+        last_wav_range = wav_slices[-1]
+        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+        if coverage < min_coverage and len(mel_slices) > 1:
+            mel_slices = mel_slices[:-1]
+            wav_slices = wav_slices[:-1]
+
+        return wav_slices, mel_slices
+
+    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
+        """
+        Computes an embedding for a single utterance. The utterance is divided in partial
+        utterances and an embedding is computed for each. The complete utterance embedding is the
+        L2-normed average embedding of the partial utterances.
+
+        TODO: independent batched version of this function
+
+        :param wav: a preprocessed utterance waveform as a numpy array of float32
+        :param return_partials: if True, the partial embeddings will also be returned along with
+        the wav slices corresponding to each partial utterance.
+        :param rate: how many partial utterances should occur per second. Partial utterances must
+        cover the span of the entire utterance, thus the rate should not be lower than the inverse
+        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
+        the minimum rate is thus 0.625.
+        :param min_coverage: when reaching the last partial utterance, it may or may not have
+        enough frames. If at least `min_coverage` of `partials_n_frames` frames are present,
+        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
+        it will be discarded. If there aren't enough frames for one partial utterance,
+        this parameter is ignored so that the function always returns at least one slice.
+        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+        `return_partials` is True, the partial utterances as a numpy array of float32 of shape
+        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+        returned.
+        """
+        # Compute where to split the utterance into partials and pad the waveform with zeros if
+        # the partial utterances cover a larger range.
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials and forward them through the model + mel = audio.wav_to_mel_spectrogram(wav) + mels = np.array([mel[s] for s in mel_slices]) + with torch.no_grad(): + mels = torch.from_numpy(mels).to(self.device) + partial_embeds = self(mels).cpu().numpy() + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wav_slices + return embed + + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): + """ + Compute the embedding of a collection of wavs (presumably from the same speaker) by + averaging their embedding and L2-normalizing it. + + :param wavs: list of wavs a numpy arrays of float32. + :param kwargs: extra arguments to embed_utterance() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). + """ + raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0) + return raw_embed / np.linalg.norm(raw_embed, 2) diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py new file mode 100644 index 00000000..6edada40 --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/__init__.py @@ -0,0 +1,35 @@ +import os +import urllib.request + +import torch + +from TTS.utils.generic_utils import get_user_data_dir +from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig + +model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" + + +def get_wavlm(device="cpu"): + """Download the model and return the model object.""" + + output_path = get_user_data_dir("tts") + + output_path = os.path.join(output_path, "wavlm") + if not os.path.exists(output_path): + os.makedirs(output_path) + + output_path = os.path.join(output_path, "WavLM-Large.pt") + if not os.path.exists(output_path): + print(f" > Downloading WavLM model to {output_path} ...") + urllib.request.urlretrieve(model_uri, output_path) + + checkpoint = torch.load(output_path, map_location=torch.device(device)) + cfg = WavLMConfig(checkpoint["cfg"]) + wavlm = WavLM(cfg).to(device) + wavlm.load_state_dict(checkpoint["model"]) + wavlm.eval() + return wavlm + + +if __name__ == "__main__": + wavlm = get_wavlm() diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/modules/freevc/wavlm/config.json new file mode 100644 index 00000000..c6f851b9 --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/config.json @@ -0,0 +1,99 @@ +{ + "_name_or_path": "./wavlm-large/", + "activation_dropout": 0.0, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "WavLMModel" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "classifier_proj_size": 256, + "codevector_dim": 768, + "contrastive_logits_temperature": 0.1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + 
"feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.0, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.075, + "mask_time_selection": "static", + "max_bucket_distance": 800, + "model_type": "wavlm", + "num_adapter_layers": 3, + "num_attention_heads": 16, + "num_buckets": 320, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_ctc_classes": 80, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 0, + "proj_codevector_dim": 768, + "replace_prob": 0.5, + "tokenizer_class": "Wav2Vec2CTCTokenizer", + "torch_dtype": "float32", + "transformers_version": "4.15.0.dev0", + "use_weighted_layer_sum": false, + "vocab_size": 32 + } \ No newline at end of file diff --git a/TTS/vc/modules/freevc/wavlm/modules.py b/TTS/vc/modules/freevc/wavlm/modules.py new file mode 100644 index 00000000..37c1a6e8 --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/modules.py @@ -0,0 +1,768 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class 
SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function""" + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2] + else: + x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2]) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask.bernoulli_(p) + mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + + # scale weights and apply mask + mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + 
module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size) + self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size) + + self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + 
relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. 
+ and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = 
torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax(attn_weights, dim=-1) + attn_weights = 
attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/modules/freevc/wavlm/wavlm.py new file mode 100644 index 00000000..7efb11bf --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/wavlm.py @@ -0,0 +1,719 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import 
torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from TTS.vc.modules.freevc.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + 
lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+ int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = np.random.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+
+ return mask
+
+
+class WavLMConfig:
+ def __init__(self, cfg=None):
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
+ self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + ) + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + ) + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + 
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + # padding_mask = padding_mask.all(-1) + padding_mask = padding_mask.any(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default", + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + 
is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride)) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1)) + self.conv_layers.append(torch.nn.LayerNorm([dim, idim])) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + 
embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + self_attn_mask=streaming_mask, + pos_bias=pos_bias, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index 01a7ff68..0bcbe7ba 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -18,6 +18,8 @@ class BaseVocoder(BaseTrainerModel): - 1D tensors `batch x 1` """ + MODEL_TYPE = "vocoder" + def __init__(self, config): super().__init__() self._set_model_args(config) diff --git a/docs/source/inference.md b/docs/source/inference.md index 60f787d3..a69d714e 100644 --- a/docs/source/inference.md +++ b/docs/source/inference.md @@ -36,8 +36,8 @@ Run a tts and a vocoder model from the released model list. Note that not every ```bash tts --text "Text for TTS" \ - --model_name "///" \ - --vocoder_name "///" \ + --model_name "tts_models///" \ + --vocoder_name "vocoder_models///" \ --out_path folder/to/save/output.wav ``` @@ -64,8 +64,17 @@ tts --text "Text for TTS" \ Run a multi-speaker TTS model from the released models list. ```bash -tts --model_name "///" --list_speaker_idxs # list the possible speaker IDs. -tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx "" +tts --model_name "tts_models///" --list_speaker_idxs # list the possible speaker IDs. +tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "tts_models///" --speaker_idx "" +``` + +Run a released voice conversion model + +```bash +tts --model_name "voice_conversion///" + --source_wav "my/source/speaker/audio.wav" + --target_wav "my/target/speaker/audio.wav" + --out_path folder/to/save/output.wav ``` **Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder. 
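The note above points to the script entry point; for completeness, here is a sketch of the same voice conversion call issued directly from the project folder. This is only an illustration: it assumes `TTS/bin/synthesize.py` accepts the same flags as the installed `tts` command, and it uses the released FreeVC model name; replace the audio paths with your own files.

```bash
# hypothetical direct invocation from the repository root, mirroring the `tts` example above
python TTS/bin/synthesize.py \
    --model_name "voice_conversion_models/multilingual/vctk/freevc24" \
    --source_wav "my/source/speaker/audio.wav" \
    --target_wav "my/target/speaker/audio.wav" \
    --out_path folder/to/save/output.wav
```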
@@ -135,4 +144,23 @@ tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_ tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav") tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav") tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav") +``` + +Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav` + +```python +tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True) +tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav") +``` + +Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can +clone voices by using any model in 🐸TTS. + +```python +tts = TTS("tts_models/de/thorsten/tacotron2-DDC") +tts.tts_with_vc_to_file( + "Wie sage ich auf Italienisch, dass ich dich liebe?", + speaker_wav="target/speaker.wav", + file_path="ouptut.wav" +) ``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2acefadc..f8730cc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ tqdm anyascii pyyaml fsspec>=2021.04.0 +aiohttp packaging # deps for examples flask diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py index a44c98e8..4072edab 100644 --- a/tests/inference_tests/test_python_api.py +++ b/tests/inference_tests/test_python_api.py @@ -28,7 +28,7 @@ class TTSTest(unittest.TestCase): def test_multi_speaker_multi_lingual_model(self): tts = TTS() - tts.load_model_by_name(tts.models[0]) # YourTTS + tts.load_tts_model_by_name(tts.models[0]) # YourTTS tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH) self.assertTrue(tts.is_multi_speaker) @@ -38,5 +38,5 @@ class TTSTest(unittest.TestCase): def test_voice_cloning(self): # pylint: disable=no-self-use tts = TTS() - tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts") + tts.load_tts_model_by_name("tts_models/multilingual/multi-dataset/your_tts") tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH) diff --git a/tests/vc_tests/__init__.py b/tests/vc_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py new file mode 100644 index 00000000..a4a4f726 --- /dev/null +++ b/tests/vc_tests/test_freevc.py @@ -0,0 +1,135 @@ +import os +import unittest + +import torch + +from tests import get_tests_input_path +from TTS.vc.configs.freevc_config import FreeVCConfig +from TTS.vc.models.freevc import FreeVC + +# pylint: disable=unused-variable +# pylint: disable=no-self-use + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +c = FreeVCConfig() + +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") +BATCH_SIZE = 3 + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class TestFreeVC(unittest.TestCase): + def _create_inputs(self, config, batch_size=2): + input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device) + input_lengths 
= torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device) + input_lengths[-1] = 30 * config.audio["hop_length"] + spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device) + mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device) + spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) + spec_lengths[-1] = spec.size(2) + waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device) + return input_dummy, input_lengths, mel, spec, spec_lengths, waveform + + @staticmethod + def _create_inputs_inference(): + source_wav = torch.rand(16000) + target_wav = torch.rand(16000) + return source_wav, target_wav + + @staticmethod + def _check_parameter_changes(model, model_ref): + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) + count += 1 + + def test_methods(self): + config = FreeVCConfig() + model = FreeVC(config).to(device) + model.load_pretrained_speaker_encoder() + model.init_multispeaker(config) + wavlm_feats = model.extract_wavlm_features(torch.rand(1, 16000)) + assert wavlm_feats.shape == (1, 1024, 49), wavlm_feats.shape + + def test_load_audio(self): + config = FreeVCConfig() + model = FreeVC(config).to(device) + wav = model.load_audio(WAV_FILE) + wav2 = model.load_audio(wav) + assert all(torch.isclose(wav, wav2)) + + def _test_forward(self, batch_size): + # create model + config = FreeVCConfig() + model = FreeVC(config).to(device) + model.train() + print(" > Num parameters for FreeVC model:%s" % (count_parameters(model))) + + _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size) + + wavlm_vec = model.extract_wavlm_features(waveform) + wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) + + y = model.forward(wavlm_vec, spec, None, mel, spec_lengths, wavlm_vec_lengths) + # TODO: assert with training implementation + + def test_forward(self): + self._test_forward(1) + self._test_forward(3) + + def _test_inference(self, batch_size): + config = FreeVCConfig() + model = FreeVC(config).to(device) + model.eval() + + _, _, mel, _, _, waveform = self._create_inputs(config, batch_size) + + wavlm_vec = model.extract_wavlm_features(waveform) + wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) + + output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths) + assert ( + output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1] + ), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}" + + def test_inference(self): + self._test_inference(1) + self._test_inference(3) + + def test_voice_conversion(self): + config = FreeVCConfig() + model = FreeVC(config).to(device) + model.eval() + + source_wav, target_wav = self._create_inputs_inference() + output_wav = model.voice_conversion(source_wav, target_wav) + assert ( + output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0] + ), f"{output_wav.shape} != {source_wav.shape}" + + def test_train_step(self): + ... + + def test_train_eval_log(self): + ... + + def test_test_run(self): + ... + + def test_load_checkpoint(self): + ... + + def test_get_criterion(self): + ... + + def test_init_from_config(self): + ... 
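The new test module above can be exercised on its own. A minimal sketch, assuming the project's usual pytest-based workflow and that the pretrained components FreeVC relies on (WavLM, speaker encoder) can be fetched on first use:

```bash
# run the whole FreeVC test module
pytest tests/vc_tests/test_freevc.py

# or a single case, e.g. the end-to-end conversion check
pytest tests/vc_tests/test_freevc.py -k test_voice_conversion
```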
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index a1f9c536..001f5ef6 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -51,6 +51,13 @@ def run_models(offset=0, step=1): # remove downloaded models shutil.rmtree(local_download_dir) shutil.rmtree(get_user_data_dir("tts")) + elif "voice_conversion_models" in model_name: + speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav") + run_cli( + f"tts --model_name {model_name} " + f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False' + ) else: # only download the model manager.download_model(model_name)