diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 885d66b3..f485514a 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -1,80 +1,46 @@
+
 import argparse
-import glob
 import os
 
-import torch
 from tqdm import tqdm
 
-from TTS.config import BaseDatasetConfig, load_config
-from TTS.speaker_encoder.utils.generic_utils import setup_model
+from argparse import RawTextHelpFormatter
+from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.audio import AudioProcessor
 
 parser = argparse.ArgumentParser(
-    description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
+    description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
+    """
+    Example runs:
+    python TTS/bin/compute_embeddings.py speaker_encoder_model.pth.tar speaker_encoder_config.json dataset_config.json embeddings_output_path/
+    """,
+    formatter_class=RawTextHelpFormatter,
 )
-parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).")
+parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
 parser.add_argument(
     "config_path",
     type=str,
-    help="Path to config file for training.",
+    help="Path to model config file.",
 )
-parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file")
-parser.add_argument("output_path", type=str, help="path for output speakers.json.")
+
 parser.add_argument(
-    "--target_dataset",
+    "config_dataset_path",
     type=str,
-    default="",
-    help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.",
+    help="Path to dataset config file.",
 )
+parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
-parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|")
+parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+
 args = parser.parse_args()
 
+c_dataset = load_config(args.config_dataset_path)
-c = load_config(args.config_path)
-ap = AudioProcessor(**c["audio"])
+meta_data_train, meta_data_eval = load_meta_data(c_dataset.datasets, eval_split=args.eval)
+wav_files = meta_data_train + meta_data_eval
-data_path = args.data_path
-split_ext = os.path.splitext(data_path)
-sep = args.separator
-
-if args.target_dataset != "":
-    # if target dataset is defined
-    dataset_config = [
-        BaseDatasetConfig(name=args.target_dataset, path=args.data_path, meta_file_train=None, meta_file_val=None),
-    ]
-    wav_files, _ = load_meta_data(dataset_config, eval_split=False)
-else:
-    # if target dataset is not defined
-    if len(split_ext) > 0 and split_ext[1].lower() == ".csv":
-        # Parse CSV
-        print(f"CSV file: {data_path}")
-        with open(data_path) as f:
-            wav_path = os.path.join(os.path.dirname(data_path), "wavs")
-            wav_files = []
-            print(f"Separator is: {sep}")
-            for line in f:
-                components = line.split(sep)
-                if len(components) != 2:
-                    print("Invalid line")
-                    continue
-                wav_file = os.path.join(wav_path, components[0] + ".wav")
-                # print(f'wav_file: {wav_file}')
-                if os.path.exists(wav_file):
-                    wav_files.append(wav_file)
-        print(f"Count of wavs imported: {len(wav_files)}")
-    else:
-        # Parse all wav files in data_path
-        wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)
-
-# define Encoder model
-model = setup_model(c)
-model.load_state_dict(torch.load(args.model_path)["model"])
-model.eval()
-if args.use_cuda:
-    model.cuda()
+speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda)
 
 # compute speaker embeddings
 speaker_mapping = {}
@@ -85,27 +51,24 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     else:
         speaker_name = None
 
-    mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
-    mel_spec = torch.FloatTensor(mel_spec[None, :, :])
-    if args.use_cuda:
-        mel_spec = mel_spec.cuda()
-    embedd = model.compute_embedding(mel_spec)
-    embedd = embedd.detach().cpu().numpy()
+    # extract the embedding
+    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
 
     # create speaker_mapping if target dataset is defined
     wav_file_name = os.path.basename(wav_file)
     speaker_mapping[wav_file_name] = {}
     speaker_mapping[wav_file_name]["name"] = speaker_name
-    speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
+    speaker_mapping[wav_file_name]["embedding"] = embedd
 
 if speaker_mapping:
     # save speaker_mapping if target dataset is defined
-    if ".json" not in args.output_path:
+    if '.json' not in args.output_path:
        mapping_file_path = os.path.join(args.output_path, "speakers.json")
     else:
         mapping_file_path = args.output_path
+    os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
 
-    speaker_manager = SpeakerManager()
+    # pylint: disable=W0212
     speaker_manager._save_json(mapping_file_path, speaker_mapping)
 
     print("Speaker embeddings saved at:", mapping_file_path)
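
For reference, the script now delegates feature extraction to SpeakerManager and writes a single mapping file. A minimal, hedged sketch of consuming that output (the path below is a placeholder for whatever output_path you pass):

    import json

    # hypothetical location; the script writes "speakers.json" inside output_path
    with open("embeddings_output_path/speakers.json", "r", encoding="utf-8") as f:
        speaker_mapping = json.load(f)

    # each entry maps a wav file name to {"name": <speaker>, "embedding": <d-vector>}
    by_speaker = {}
    for wav_name, entry in speaker_mapping.items():
        by_speaker.setdefault(entry["name"], []).append(entry["embedding"])

    print({name: len(vectors) for name, vectors in by_speaker.items()})
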
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index b0159b86..1cbc5516 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -227,7 +227,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     ap = AudioProcessor(**c.audio)
 
     # load data instances
-    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
+    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=args.eval)
 
     # use eval and training partitions
     meta_data = meta_data_train + meta_data_eval
@@ -271,6 +271,7 @@ if __name__ == "__main__":
     parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
     parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
+    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
 
     args = parser.parse_args()
     c = load_config(args.config_path)
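
For context, the new --eval flag is only threaded into load_meta_data; a hedged sketch of the call it ends up making (the config path is a placeholder):

    from TTS.config import load_config
    from TTS.tts.datasets import load_meta_data

    c = load_config("dataset_config.json")  # hypothetical dataset config
    # returns (train_items, eval_items); the script concatenates both partitions
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
    meta_data = meta_data_train + meta_data_eval
    print(len(meta_data_train), len(meta_data_eval))
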
+ "--config_path", type=str, help="Path to dataset config file.", required=True ) - - parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.") - args = parser.parse_args() - preprocessor = _get_preprocessor_by_name(args.dataset) - items = preprocessor(os.path.dirname(args.meta_file), os.path.basename(args.meta_file)) + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_meta_data(c.datasets, eval_split=True) + items = train_items + eval_items + texts = "".join(item[0] for item in items) chars = set(texts) lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + print(f" > Number of unique characters: {len(chars)}") print(f" > Unique characters: {''.join(sorted(chars))}") print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") - + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") if __name__ == "__main__": main() diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index 05a56675..21439d6b 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -1,4 +1,5 @@ import torch +import numpy as np from torch import nn @@ -70,24 +71,32 @@ class LSTMSpeakerEncoder(nn.Module): d = torch.nn.functional.normalize(d, p=2, dim=1) return d - def compute_embedding(self, x, num_frames=160, overlap=0.5): + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): """ Generate embeddings for a batch of utterances x: 1xTxD """ - num_overlap = int(num_frames * overlap) max_len = x.shape[1] - embed = None - cur_iter = 0 - for offset in range(0, max_len, num_frames - num_overlap): - cur_iter += 1 - end_offset = min(x.shape[1], offset + num_frames) + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len-num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset+num_frames) frames = x[:, offset:end_offset] - if embed is None: - embed = self.inference(frames) - else: - embed += self.inference(frames) - return embed / cur_iter + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): """ @@ -110,9 +119,11 @@ class LSTMSpeakerEncoder(nn.Module): return embed / num_iters # pylint: disable=unused-argument, redefined-builtin - def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False): + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): state = torch.load(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() if eval: self.eval() assert not self.training diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index ce86b01f..29f3ae61 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module): embeddings = torch.mean(embeddings, dim=0, keepdim=True) return embeddings + + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = 
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index ce86b01f..29f3ae61 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module):
             embeddings = torch.mean(embeddings, dim=0, keepdim=True)
 
         return embeddings
+
+    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if use_cuda:
+            self.cuda()
+        if eval:
+            self.eval()
+            assert not self.training
\ No newline at end of file
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index db7841f4..ef5299cb 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -202,16 +202,20 @@ def libri_tts(root_path, meta_files=None):
     items = []
     if meta_files is None:
         meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
+    else:
+        if isinstance(meta_files, str):
+            meta_files = [os.path.join(root_path, meta_files)]
+
     for meta_file in meta_files:
         _meta_file = os.path.basename(meta_file).split(".")[0]
-        speaker_name = _meta_file.split("_")[0]
-        chapter_id = _meta_file.split("_")[1]
-        _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
         with open(meta_file, "r") as ttf:
             for line in ttf:
                 cols = line.split("\t")
-                wav_file = os.path.join(_root_path, cols[0] + ".wav")
-                text = cols[1]
+                file_name = cols[0]
+                speaker_name, chapter_id, *_ = cols[0].split("_")
+                _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
+                wav_file = os.path.join(_root_path, file_name + ".wav")
+                text = cols[2]
                 items.append([text, wav_file, "LTTS_" + speaker_name])
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
@@ -287,6 +291,17 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
     return items
 
 
+def mls(root_path, meta_files=None):
+    """http://www.openslr.org/94/"""
+    items = []
+    with open(os.path.join(root_path, meta_files), "r") as meta:
+        for line in meta:
+            file, text = line.split('\t')
+            text = text[:-1]
+            speaker, book, *_ = file.split('_')
+            wav_file = os.path.join(root_path, os.path.dirname(meta_files), 'audio', speaker, book, file + ".wav")
+            items.append([text, wav_file, "MLS_" + speaker])
+    return items
 # ======================================== VOX CELEB ===========================================
 def voxceleb2(root_path, meta_file=None):
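
The new mls formatter assumes the OpenSLR MLS layout: a tab-separated transcript file whose first column encodes speaker, book and utterance ids. A hedged sketch of the path it reconstructs (root and metadata paths below are hypothetical):

    import os

    root_path = "/data/mls_english"        # hypothetical dataset root
    meta_files = "train/transcripts.txt"   # hypothetical metadata file, relative to root_path
    line = "10214_10108_000000\tthis is an example transcript\n"

    file, text = line.split("\t")
    text = text[:-1]                       # drop the trailing newline
    speaker, book, *_ = file.split("_")
    wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav")
    print(wav_file)  # /data/mls_english/train/audio/10214/10108/10214_10108_000000.wav
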
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 8febcbbf..a8c9e0f6 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -59,6 +59,7 @@ class SpeakerManager:
         speaker_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
+        use_cuda: bool = False,
     ):
 
         self.data_items = []
@@ -67,6 +68,7 @@ class SpeakerManager:
         self.clip_ids = []
         self.speaker_encoder = None
         self.speaker_encoder_ap = None
+        self.use_cuda = use_cuda
 
         if data_items:
             self.speaker_ids, self.speaker_names, _ = self.parse_speakers_from_data(self.data_items)
@@ -222,11 +224,11 @@ class SpeakerManager:
         """
         self.speaker_encoder_config = load_config(config_path)
         self.speaker_encoder = setup_model(self.speaker_encoder_config)
-        self.speaker_encoder.load_checkpoint(config_path, model_path, True)
+        self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
         self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
         # normalize the input audio level and trim silences
-        self.speaker_encoder_ap.do_sound_norm = True
-        self.speaker_encoder_ap.do_trim_silence = True
+        # self.speaker_encoder_ap.do_sound_norm = True
+        # self.speaker_encoder_ap.do_trim_silence = True
 
     def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
         """Compute a d_vector from a given audio file.
@@ -242,6 +244,8 @@ class SpeakerManager:
             waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
             spec = self.speaker_encoder_ap.melspectrogram(waveform)
             spec = torch.from_numpy(spec.T)
+            if self.use_cuda:
+                spec = spec.cuda()
             spec = spec.unsqueeze(0)
             d_vector = self.speaker_encoder.compute_embedding(spec)
             return d_vector
@@ -272,6 +276,8 @@ class SpeakerManager:
         feats = torch.from_numpy(feats)
         if feats.ndim == 2:
             feats = feats.unsqueeze(0)
+        if self.use_cuda:
+            feats = feats.cuda()
         return self.speaker_encoder.compute_embedding(feats)
 
     def run_umap(self):
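
With use_cuda plumbed through SpeakerManager, the GPU toggle becomes a single constructor argument; a minimal usage sketch (checkpoint, config and wav paths are placeholders):

    from TTS.tts.utils.speakers import SpeakerManager

    manager = SpeakerManager(
        encoder_model_path="speaker_encoder_model.pth.tar",
        encoder_config_path="speaker_encoder_config.json",
        use_cuda=True,  # set False on CPU-only machines
    )
    # both the encoder weights and the input spectrogram are moved to GPU internally
    d_vector = manager.compute_d_vector_from_clip("sample.wav")
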
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "\n", - "(function(root) {\n", - " function now() {\n", - " return new Date();\n", - " }\n", - "\n", - " var force = true;\n", - "\n", - " if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n", - " root._bokeh_onload_callbacks = [];\n", - " root._bokeh_is_loading = undefined;\n", - " }\n", - "\n", - " var JS_MIME_TYPE = 'application/javascript';\n", - " var HTML_MIME_TYPE = 'text/html';\n", - " var EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n", - " var CLASS_NAME = 'output_bokeh rendered_html';\n", - "\n", - " /**\n", - " * Render data to the DOM node\n", - " */\n", - " function render(props, node) {\n", - " var script = document.createElement(\"script\");\n", - " node.appendChild(script);\n", - " }\n", - "\n", - " /**\n", - " * Handle when an output is cleared or removed\n", - " */\n", - " function handleClearOutput(event, handle) {\n", - " var cell = handle.cell;\n", - "\n", - " var id = cell.output_area._bokeh_element_id;\n", - " var server_id = cell.output_area._bokeh_server_id;\n", - " // Clean up Bokeh references\n", - " if (id != null && id in Bokeh.index) {\n", - " Bokeh.index[id].model.document.clear();\n", - " delete Bokeh.index[id];\n", - " }\n", - "\n", - " if (server_id !== undefined) {\n", - " // Clean up Bokeh references\n", - " var cmd = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n", - " cell.notebook.kernel.execute(cmd, {\n", - " iopub: {\n", - " output: function(msg) {\n", - " var id = msg.content.text.trim();\n", - " if (id in Bokeh.index) {\n", - " Bokeh.index[id].model.document.clear();\n", - " delete Bokeh.index[id];\n", - " }\n", - " }\n", - " }\n", - " });\n", - " // Destroy server and session\n", - " var cmd = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n", - " cell.notebook.kernel.execute(cmd);\n", - " }\n", - " }\n", - "\n", - " /**\n", - " * Handle when a new output is added\n", - " */\n", - " function handleAddOutput(event, handle) {\n", - " var output_area = handle.output_area;\n", - " var output = handle.output;\n", - "\n", - " // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n", - " if ((output.output_type != \"display_data\") || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n", - " return\n", - " }\n", - "\n", - " var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n", - "\n", - " if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n", - " toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n", - " // store reference to embed id on output_area\n", - " output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n", - " }\n", - " if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n", - " var bk_div = document.createElement(\"div\");\n", - " bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n", - " var script_attrs = bk_div.children[0].attributes;\n", - " for (var i = 0; i < script_attrs.length; i++) {\n", - " toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n", - " }\n", - " // store reference to server id on output_area\n", - " output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n", - " }\n", - " }\n", - "\n", - " function register_renderer(events, OutputArea) {\n", - "\n", - " 
function append_mime(data, metadata, element) {\n", - " // create a DOM node to render to\n", - " var toinsert = this.create_output_subarea(\n", - " metadata,\n", - " CLASS_NAME,\n", - " EXEC_MIME_TYPE\n", - " );\n", - " this.keyboard_manager.register_events(toinsert);\n", - " // Render to node\n", - " var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n", - " render(props, toinsert[toinsert.length - 1]);\n", - " element.append(toinsert);\n", - " return toinsert\n", - " }\n", - "\n", - " /* Handle when an output is cleared or removed */\n", - " events.on('clear_output.CodeCell', handleClearOutput);\n", - " events.on('delete.Cell', handleClearOutput);\n", - "\n", - " /* Handle when a new output is added */\n", - " events.on('output_added.OutputArea', handleAddOutput);\n", - "\n", - " /**\n", - " * Register the mime type and append_mime function with output_area\n", - " */\n", - " OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n", - " /* Is output safe? */\n", - " safe: true,\n", - " /* Index of renderer in `output_area.display_order` */\n", - " index: 0\n", - " });\n", - " }\n", - "\n", - " // register the mime type if in Jupyter Notebook environment and previously unregistered\n", - " if (root.Jupyter !== undefined) {\n", - " var events = require('base/js/events');\n", - " var OutputArea = require('notebook/js/outputarea').OutputArea;\n", - "\n", - " if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n", - " register_renderer(events, OutputArea);\n", - " }\n", - " }\n", - "\n", - " \n", - " if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n", - " root._bokeh_timeout = Date.now() + 5000;\n", - " root._bokeh_failed_load = false;\n", - " }\n", - "\n", - " var NB_LOAD_WARNING = {'data': {'text/html':\n", - " \"\\n\"+\n", - " \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n", - " \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n", - " \"
\\n\"+\n", - " \"\\n\"+\n",
- " \"from bokeh.resources import INLINE\\n\"+\n",
- " \"output_notebook(resources=INLINE)\\n\"+\n",
- " \"
\\n\"+\n",
- " \"\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n \"
\\n\"+\n \"\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"
\\n\"+\n \"