diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 9ed459a2..ab5754f7 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,16 +1,16 @@ import argparse -import glob import os import torch import numpy as np from tqdm import tqdm +from TTS.config import load_config from TTS.speaker_encoder.utils.generic_utils import setup_model from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor -from TTS.config import load_config + parser = argparse.ArgumentParser( description='Compute embedding vectors for each wav file in a dataset.' diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 8fbc8f8e..fccbc311 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -1,6 +1,5 @@ """Find all the unique characters in a dataset""" import argparse -import os from argparse import RawTextHelpFormatter from TTS.tts.datasets.preprocess import load_meta_data from TTS.config import load_config @@ -31,7 +30,8 @@ def main(): texts = "".join(item[0] for item in items) chars = set(texts) lower_chars = filter(lambda c: c.islower(), chars) - chars_force_lower = set([c.lower() for c in chars]) + chars_force_lower = [c.lower() for c in chars]) + chars_force_lower = set(chars_force_lower) print(f" > Number of unique characters: {len(chars)}") print(f" > Unique characters: {''.join(sorted(chars))}") diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index 7fbc01b8..23d3f3c1 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -365,12 +365,11 @@ def mls(root_path, meta_files=None): """http://www.openslr.org/94/""" items = [] with open(os.path.join(root_path, meta_files), "r") as meta: - isTrain = "train" in meta_files for line in meta: file, text = line.split('\t') text = text[:-1] - speaker, book, no = file.split('_') - wav_file = os.path.join(root_path, "train" if isTrain else "dev", 'audio', speaker, book, file + ".wav") + speaker, book, *_ = file.split('_') + wav_file = os.path.join(root_path, os.path.dirname(meta_files), 'audio', speaker, book, file + ".wav") items.append([text, wav_file, "MLS_" + speaker]) return items diff --git a/tests/test_speaker_encoder.py b/tests/test_speaker_encoder.py index f56a9577..cecbd493 100644 --- a/tests/test_speaker_encoder.py +++ b/tests/test_speaker_encoder.py @@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase): assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" # compute d for a given batch dummy_input = T.rand(1, 240, 80) # B x T x D - output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5) + output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5) assert output.shape[0] == 1 assert output.shape[1] == 256 assert len(output.shape) == 2