mirror of https://github.com/coqui-ai/TTS.git
split dataset outside preprocessor
This commit is contained in:
parent b7036e458d
commit fd081c49b7
@@ -146,7 +146,7 @@ def common_voice(root_path, meta_file):
     return items


-def libri_tts(root_path, meta_files=None, is_eval=False):
+def libri_tts(root_path, meta_files=None):
     """https://ai.google/tools/datasets/libri-tts/"""
     items = []
     if meta_files is None:
@@ -164,6 +164,4 @@ def libri_tts(root_path, meta_files=None, is_eval=False):
                 items.append([text, wav_file, speaker_name])
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav file is not exist - {item[1]}"
-    if meta_files is None:
-        return items[:500] if is_eval else items[500:]
     return items
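With the is_eval flag removed, libri_tts now always returns the full list of [text, wav_file, speaker_name] entries and leaves the train/eval split to the caller. A minimal sketch of the new calling pattern, assuming the preprocessor module is datasets.preprocess and using an illustrative data path:

    # Sketch only: the module path and data path are assumptions, not part of the diff.
    from datasets.preprocess import libri_tts
    from utils.generic_utils import split_dataset

    items = libri_tts("/data/LibriTTS/train-clean-100")     # full dataset, no fixed 500-item slice
    meta_data_eval, meta_data_train = split_dataset(items)   # split now happens outside the preprocessor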
train.py (13 changed lines)
@@ -21,7 +21,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
                                  save_checkpoint, sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file, setup_model)
+                                 set_init_dict, copy_config_file, setup_model,
+                                 split_dataset)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -44,15 +45,15 @@ def setup_loader(is_val=False, verbose=False):
     global meta_data_train
     global meta_data_eval
     if "meta_data_train" not in globals():
-        if c.meta_file_train:
+        if c.meta_file_train is not None:
             meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
         else:
-            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=False)
+            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
-    if "meta_data_eval" not in globals():
+    if "meta_data_eval" not in globals() and c.run_eval:
-        if c.meta_file_val:
+        if c.meta_file_val is not None:
             meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val)
         else:
-            meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=True)
+            meta_data_eval, meta_data_train = split_dataset(meta_data_train)
     if is_val and not c.run_eval:
         loader = None
     else:
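The effect in setup_loader: when meta_file_val is not set, the eval set is carved out of the training items with split_dataset, and meta_data_eval is only built when run_eval is enabled. A hypothetical config fragment showing just the fields this hunk reads (values are illustrative; c is the loaded config object used throughout train.py):

    # Hypothetical config values, not taken from the repo's config file.
    c.dataset = "libri_tts"          # selects the preprocessor via get_preprocessor_by_name
    c.data_path = "/data/LibriTTS/"  # illustrative path
    c.meta_file_train = None         # preprocessor discovers metadata under data_path
    c.meta_file_val = None           # None -> eval items are split off meta_data_train
    c.run_eval = True                # False -> meta_data_eval is never built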
utils/generic_utils.py
@@ -10,7 +10,7 @@ import torch
 import subprocess
 import importlib
 import numpy as np
-from collections import OrderedDict
+from collections import OrderedDict, Counter
 from torch.autograd import Variable
 from utils.text import text_to_sequence

@@ -287,3 +287,26 @@ def setup_model(num_chars, num_speakers, c):
                      location_attn=c.location_attn,
                      separate_stopnet=c.separate_stopnet)
     return model
+
+
+def split_dataset(items):
+    is_multi_speaker = False
+    speakers = [item[-1] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        # most stupid code ever -- Fix it !
+        while len(items_eval) < eval_split_size:
+            speakers = [item[-1] for item in items]
+            speaker_counter = Counter(speakers)
+            item_idx = np.random.randint(0, len(items))
+            if speaker_counter[items[item_idx][-1]] > 1:
+                items_eval.append(items[item_idx])
+                del items[item_idx]
+        return items_eval, items
+    else:
+        return items[:eval_split_size], items[eval_split_size:]
+
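split_dataset returns (eval_items, train_items). Seeding NumPy with 0 makes the shuffle, and therefore the split, reproducible across runs; the eval split is capped at 500 items or 1% of the dataset, whichever is smaller; and in the multi-speaker branch an item is only moved to the eval split if its speaker still has more than one remaining item, so every speaker keeps at least one training sample. Note that the function shuffles items in place. A small usage sketch with made-up items (the [text, wav_path, speaker_name] layout matches the preprocessor output):

    # Usage sketch with synthetic items; real items come from the dataset preprocessors.
    from utils.generic_utils import split_dataset

    items = [["some transcript", f"/data/wavs/{i:04d}.wav", f"speaker_{i % 4}"]
             for i in range(2000)]
    eval_items, train_items = split_dataset(items)   # shuffles `items` in place
    print(len(eval_items), len(train_items))         # 20 (1% of 2000) and 1980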