diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index caf86a42..13f9368b 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -37,6 +37,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
 
     dataset = EncoderDataset(
+        c,
         ap,
         meta_data_eval if is_val else meta_data_train,
         voice_len=c.voice_len,
diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py
index 66947b2c..50164eaf 100644
--- a/TTS/encoder/configs/base_encoder_config.py
+++ b/TTS/encoder/configs/base_encoder_config.py
@@ -46,8 +46,8 @@ class BaseEncoderConfig(BaseTrainingConfig):
     # data loader
     num_classes_in_batch: int = MISSING
     num_utter_per_class: int = MISSING
-    eval_num_classes_in_batch: int = MISSING
-    eval_num_utter_per_class: int = MISSING
+    eval_num_classes_in_batch: int = None
+    eval_num_utter_per_class: int = None
     num_loader_workers: int = MISSING
     voice_len: float = 1.6
 
diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py
index ef24daa1..47b62393 100644
--- a/TTS/encoder/dataset.py
+++ b/TTS/encoder/dataset.py
@@ -8,6 +8,7 @@ from TTS.encoder.utils.generic_utils import AugmentWAV
 class EncoderDataset(Dataset):
     def __init__(
         self,
+        config,
         ap,
         meta_data,
         voice_len=1.6,
@@ -25,6 +26,7 @@ class EncoderDataset(Dataset):
             verbose (bool): print diagnostic information.
         """
         super().__init__()
+        self.config = config
         self.items = meta_data
         self.sample_rate = ap.sample_rate
         self.seq_len = int(voice_len * self.sample_rate)
@@ -62,9 +64,9 @@ class EncoderDataset(Dataset):
 
     def __parse_items(self):
         class_to_utters = {}
-        for i in self.items:
-            path_ = i["audio_file"]
-            speaker_ = i["speaker_name"]
+        for item in self.items:
+            path_ = item["audio_file"]
+            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
             if class_name in class_to_utters.keys():
                 class_to_utters[class_name].append(path_)
             else:
@@ -82,8 +84,8 @@ class EncoderDataset(Dataset):
 
         new_items = []
         for item in self.items:
-            path_ = item[1]
-            class_name = item[2]
+            path_ = item["audio_file"]
+            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
             # ignore filtered classes
             if class_name not in classes:
                 continue
@@ -94,6 +96,7 @@
             new_items.append({"wav_file_path": path_, "class_name": class_name})
 
         return classes, new_items
+
     def __len__(self):
         return len(self.items)
 
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index fbfc1c25..554dede1 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -435,7 +435,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
            if speaker_id in ignored_speakers:
                continue
-        items.append([speaker_id, wav_file, emotion_id])
+        items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
     return items
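
For reviewers, a minimal runnable sketch of the label-selection behavior these hunks introduce: the `emotion` formatter now emits dict items, and `EncoderDataset` picks the grouping key from `config.model`. `EmotionConfig` and `resolve_class_name` are illustrative stand-ins, not names from the patch; only the `model` field of the real encoder config is exercised here.

```python
from dataclasses import dataclass


@dataclass
class EmotionConfig:
    # Hypothetical stand-in for the encoder config; the patch only
    # branches on `config.model` inside EncoderDataset.__parse_items.
    model: str = "emotion_encoder"


def resolve_class_name(item: dict, config) -> str:
    # Mirrors the conditional added in the dataset.py hunks: emotion
    # encoders group utterances by emotion, every other encoder model
    # keeps grouping by speaker.
    return item["emotion_name"] if config.model == "emotion_encoder" else item["speaker_name"]


# Items arrive as dicts (see the formatters.py hunk), not positional lists.
item = {"audio_file": "wavs/0001.wav", "speaker_name": "spk_1", "emotion_name": "angry"}

print(resolve_class_name(item, EmotionConfig()))                         # -> "angry"
print(resolve_class_name(item, EmotionConfig(model="speaker_encoder")))  # -> "spk_1"
```

Switching the formatter output from positional lists to dicts is what lets the same dataset code serve both encoder types: the old `item[1]` / `item[2]` indexing baked the speaker layout into the loader, while named fields let the config decide which label to train on.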