Unit tests fixs

This commit is contained in:
Edresson Casanova 2022-03-07 19:26:24 -03:00
parent 631aec6e88
commit 247da8ef12
4 changed files with 12 additions and 8 deletions

View File

@ -37,6 +37,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
dataset = EncoderDataset(
c,
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len,

View File

@ -46,8 +46,8 @@ class BaseEncoderConfig(BaseTrainingConfig):
# data loader
num_classes_in_batch: int = MISSING
num_utter_per_class: int = MISSING
eval_num_classes_in_batch: int = MISSING
eval_num_utter_per_class: int = MISSING
eval_num_classes_in_batch: int = None
eval_num_utter_per_class: int = None
num_loader_workers: int = MISSING
voice_len: float = 1.6

View File

@ -8,6 +8,7 @@ from TTS.encoder.utils.generic_utils import AugmentWAV
class EncoderDataset(Dataset):
def __init__(
self,
config,
ap,
meta_data,
voice_len=1.6,
@ -25,6 +26,7 @@ class EncoderDataset(Dataset):
verbose (bool): print diagnostic information.
"""
super().__init__()
self.config = config
self.items = meta_data
self.sample_rate = ap.sample_rate
self.seq_len = int(voice_len * self.sample_rate)
@ -62,9 +64,9 @@ class EncoderDataset(Dataset):
def __parse_items(self):
class_to_utters = {}
for i in self.items:
path_ = i["audio_file"]
speaker_ = i["speaker_name"]
for item in self.items:
path_ = item["audio_file"]
class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
if class_name in class_to_utters.keys():
class_to_utters[class_name].append(path_)
else:
@ -82,8 +84,8 @@ class EncoderDataset(Dataset):
new_items = []
for item in self.items:
path_ = item[1]
class_name = item[2]
path_ = item["audio_file"]
class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
# ignore filtered classes
if class_name not in classes:
continue
@ -94,6 +96,7 @@ class EncoderDataset(Dataset):
new_items.append({"wav_file_path": path_, "class_name": class_name})
return classes, new_items
def __len__(self):
return len(self.items)

View File

@ -435,7 +435,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append([speaker_id, wav_file, emotion_id])
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
return items