mirror of https://github.com/coqui-ai/TTS.git
Unit tests fixes
This commit is contained in:
parent 631aec6e88
commit 247da8ef12
@@ -37,6 +37,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
 
     dataset = EncoderDataset(
+        c,
         ap,
         meta_data_eval if is_val else meta_data_train,
         voice_len=c.voice_len,
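With the dataset now taking the config as its first argument, the call site passes `c` through, and the loader resolves its sampler sizes per split. A minimal sketch of that per-split selection, assuming `c` is the global training config as in the hunk above (the helper name is hypothetical):

def resolve_batch_params(c, is_val):
    # training uses the regular fields; validation falls back to the eval_* ones
    num_classes = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
    num_utter = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
    return num_classes, num_utter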
@@ -46,8 +46,8 @@ class BaseEncoderConfig(BaseTrainingConfig):
     # data loader
     num_classes_in_batch: int = MISSING
     num_utter_per_class: int = MISSING
-    eval_num_classes_in_batch: int = MISSING
-    eval_num_utter_per_class: int = MISSING
+    eval_num_classes_in_batch: int = None
+    eval_num_utter_per_class: int = None
 
     num_loader_workers: int = MISSING
     voice_len: float = 1.6
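Switching the eval_* defaults from MISSING to None makes them optional: MISSING fields are expected to be set before training starts, while None can be handled at the call site. A hypothetical usage sketch, assuming a SpeakerEncoderConfig subclass under TTS.encoder.configs keeps these fields:

from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig  # assumed module path

config = SpeakerEncoderConfig(
    num_classes_in_batch=32,  # classes sampled into each training batch
    num_utter_per_class=4,    # utterances drawn per class
    voice_len=1.6,            # seconds of audio per sample
)
# eval_num_classes_in_batch / eval_num_utter_per_class may now stay unset (None).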
@@ -8,6 +8,7 @@ from TTS.encoder.utils.generic_utils import AugmentWAV
 class EncoderDataset(Dataset):
     def __init__(
         self,
+        config,
         ap,
         meta_data,
         voice_len=1.6,
@@ -25,6 +26,7 @@ class EncoderDataset(Dataset):
             verbose (bool): print diagnostic information.
         """
         super().__init__()
+        self.config = config
         self.items = meta_data
         self.sample_rate = ap.sample_rate
         self.seq_len = int(voice_len * self.sample_rate)
@@ -62,9 +64,9 @@ class EncoderDataset(Dataset):
 
     def __parse_items(self):
         class_to_utters = {}
-        for i in self.items:
-            path_ = i["audio_file"]
-            speaker_ = i["speaker_name"]
+        for item in self.items:
+            path_ = item["audio_file"]
+            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
             if class_name in class_to_utters.keys():
                 class_to_utters[class_name].append(path_)
             else:
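Both parsing loops now branch on the model type stored on the dataset. The shared lookup, factored into a sketch for clarity (the helper is hypothetical; the commit uses the inline conditional shown above):

def grouping_key(self, item):
    # emotion encoders group utterances by emotion label,
    # every other encoder model groups them by speaker
    if self.config.model == "emotion_encoder":
        return item["emotion_name"]
    return item["speaker_name"]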
@@ -82,8 +84,8 @@ class EncoderDataset(Dataset):
 
         new_items = []
         for item in self.items:
-            path_ = item[1]
-            class_name = item[2]
+            path_ = item["audio_file"]
+            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
             # ignore filtered classes
             if class_name not in classes:
                 continue
@@ -94,6 +96,7 @@ class EncoderDataset(Dataset):
             new_items.append({"wav_file_path": path_, "class_name": class_name})
 
         return classes, new_items
 
     def __len__(self):
         return len(self.items)
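After parsing, each entry of new_items is a two-key dict, so __getitem__ can address fields by name instead of position. The shape, with illustrative values:

item = {
    "wav_file_path": "/data/wavs/0001.wav",  # illustrative path
    "class_name": "speaker_a",               # speaker or emotion label
}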
@@ -435,7 +435,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
             if speaker_id in ignored_speakers:
                 continue
-        items.append([speaker_id, wav_file, emotion_id])
+        items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
     return items
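The formatter change is the other half of the dict migration: metadata entries stop being positional lists, so consumers such as EncoderDataset index by key rather than by offset. One entry before and after, with illustrative values:

# before: positional, consumers had to know that index 1 holds the wav path
old_item = ["spk_1", "wavs/0001.wav", "happy"]

# after: keyed, matching the lookups in the dataset hunks above
new_item = {"audio_file": "wavs/0001.wav", "speaker_name": "spk_1", "emotion_name": "happy"}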