Fix dataset handling with the new embedding file keys (#1991)

commit 3faccbda97 (parent 0a112f7841)
Edresson Casanova, 2022-09-19 18:44:14 -03:00, committed by GitHub
5 changed files with 34 additions and 20 deletions
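
In short: pre-computed embedding/d-vector files, and the sample dicts that reference them, are no longer keyed by the wav file's basename but by a per-sample "audio_unique_name". A rough before/after of the key scheme, based on the fixture change at the bottom of this diff (the "ljspeech" prefix is illustrative; it comes from the configured dataset_name and can be empty, as in the test file):

    # before: entries keyed by the wav basename
    old_style = {"LJ001-0001.wav": {"name": "ljspeech-0", "embedding": [0.0554, -0.0947]}}

    # after: entries keyed by f"{dataset_name}#{relative_path_without_extension}"
    new_style = {"ljspeech#/wavs/LJ001-0001": {"name": "ljspeech-0", "embedding": [0.0554, -0.0947]}}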


@@ -102,11 +102,9 @@ speaker_mapping = {}
 for idx, fields in enumerate(tqdm(samples)):
     class_name = fields[class_name_key]
     audio_file = fields["audio_file"]
-    dataset_name = fields["dataset_name"]
+    embedding_key = fields["audio_unique_name"]
     root_path = fields["root_path"]
-    relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
-    embedding_key = f"{dataset_name}#{relfilepath}"
     if args.old_file is not None and embedding_key in encoder_manager.clip_ids:
         # get the embedding from the old file
         embedd = encoder_manager.get_embedding_by_clip(embedding_key)


@@ -59,6 +59,18 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     return items[:eval_split_size], items[eval_split_size:]
 
 
+def add_extra_keys(metadata, language, dataset_name):
+    for item in metadata:
+        # add language name
+        item["language"] = language
+        # add unique audio name
+        relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0]
+        audio_unique_name = f"{dataset_name}#{relfilepath}"
+        item["audio_unique_name"] = audio_unique_name
+    return metadata
+
+
 def load_tts_samples(
     datasets: Union[List[Dict], Dict],
     eval_split=True,
@@ -111,15 +123,15 @@ def load_tts_samples(
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
         assert len(meta_data_train) > 0, f" [!] No training samples found in {root_path}/{meta_file_train}"
-        meta_data_train = [{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_train]
+
+        meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
+
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [
-                    {**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_eval
-                ]
+                meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name)
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
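
For illustration, a minimal sketch of what the new helper adds to a formatter item (the item values are invented; the helper body mirrors the one added above):

    import os

    def add_extra_keys(metadata, language, dataset_name):
        # mirrors the helper added in this commit
        for item in metadata:
            item["language"] = language
            relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0]
            item["audio_unique_name"] = f"{dataset_name}#{relfilepath}"
        return metadata

    items = [{
        "text": "an example transcript",
        "audio_file": "/data/LJSpeech-1.1/wavs/LJ001-0001.wav",
        "speaker_name": "ljspeech",
        "root_path": "/data/LJSpeech-1.1",
    }]
    items = add_extra_keys(items, language="en", dataset_name="ljspeech")

    print(items[0]["language"])           # en
    print(items[0]["audio_unique_name"])  # ljspeech#/wavs/LJ001-0001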


@@ -256,6 +256,7 @@ class TTSDataset(Dataset):
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
             "wav_file_name": os.path.basename(item["audio_file"]),
+            "audio_unique_name": item["audio_unique_name"],
         }
         return sample
@@ -397,8 +398,8 @@
             language_ids = None
         # get pre-computed d-vectors
         if self.d_vector_mapping is not None:
-            wav_files_names = list(batch["wav_file_name"])
-            d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
+            embedding_keys = list(batch["audio_unique_name"])
+            d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
         else:
             d_vectors = None
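
The lookup side stays simple: the collated batch carries the unique names, and each d-vector is fetched by that key. A schematic of the lookup (the mapping content is invented; Vits.format_batch below applies the same pattern via batch["audio_unique_names"]):

    import torch

    # hypothetical d-vector file content, keyed by audio_unique_name
    d_vector_mapping = {
        "ljspeech#/wavs/LJ001-0001": {"name": "ljspeech-0", "embedding": [0.05, 0.01]},
        "ljspeech#/wavs/LJ001-0002": {"name": "ljspeech-1", "embedding": [0.02, 0.03]},
    }

    batch = {"audio_unique_name": ["ljspeech#/wavs/LJ001-0001", "ljspeech#/wavs/LJ001-0002"]}

    embedding_keys = list(batch["audio_unique_name"])
    d_vectors = [d_vector_mapping[w]["embedding"] for w in embedding_keys]
    d_vectors = torch.FloatTensor(d_vectors)  # shape [B, D]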


@@ -284,6 +284,7 @@ class VitsDataset(TTSDataset):
             "wav_file": wav_filename,
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
+            "audio_unique_name": item["audio_unique_name"],
         }
 
     @property
@@ -308,6 +309,7 @@
             - language_names: :math:`[B]`
             - audiofile_paths: :math:`[B]`
             - raw_texts: :math:`[B]`
+            - audio_unique_names: :math:`[B]`
         """
         # convert list of dicts to dict of lists
         B = len(batch)
@@ -348,6 +350,7 @@
             "language_names": batch["language_name"],
             "audio_files": batch["wav_file"],
             "raw_text": batch["raw_text"],
+            "audio_unique_names": batch["audio_unique_name"],
         }
@@ -1470,7 +1473,7 @@ class Vits(BaseTTS):
         # get d_vectors from audio file names
         if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
             d_vector_mapping = self.speaker_manager.embeddings
-            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
+            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]]
             d_vectors = torch.FloatTensor(d_vectors)
 
         # get language ids from language names


@@ -1,5 +1,5 @@
 {
-    "LJ001-0001.wav": {
+    "#/wavs/LJ001-0001": {
         "name": "ljspeech-0",
         "embedding": [
             0.05539746582508087,
@@ -260,7 +260,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0002.wav": {
+    "#/wavs/LJ001-0002": {
         "name": "ljspeech-1",
         "embedding": [
             0.05539746582508087,
@@ -521,7 +521,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0003.wav": {
+    "#/wavs/LJ001-0003": {
         "name": "ljspeech-2",
         "embedding": [
             0.05539746582508087,
@@ -782,7 +782,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0004.wav": {
+    "#/wavs/LJ001-0004": {
         "name": "ljspeech-3",
         "embedding": [
             0.05539746582508087,
@@ -1043,7 +1043,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0005.wav": {
+    "#/wavs/LJ001-0005": {
         "name": "ljspeech-4",
         "embedding": [
             0.05539746582508087,
@@ -1304,7 +1304,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0006.wav": {
+    "#/wavs/LJ001-0006": {
         "name": "ljspeech-5",
         "embedding": [
             0.05539746582508087,
@@ -1565,7 +1565,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0007.wav": {
+    "#/wavs/LJ001-0007": {
         "name": "ljspeech-6",
         "embedding": [
             0.05539746582508087,
@@ -1826,7 +1826,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0008.wav": {
+    "#/wavs/LJ001-0008": {
         "name": "ljspeech-7",
         "embedding": [
             0.05539746582508087,
@@ -2087,7 +2087,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0009.wav": {
+    "#/wavs/LJ001-0009": {
         "name": "ljspeech-8",
         "embedding": [
             0.05539746582508087,
@@ -2348,7 +2348,7 @@
             -0.09469571709632874
         ]
     },
-    "LJ001-0010.wav": {
+    "#/wavs/LJ001-0010": {
         "name": "ljspeech-9",
         "embedding": [
             0.05539746582508087,