From 33ac13e44ea34dc2adeb98755cdbefbfab79f0ea Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 1 Mar 2022 10:34:05 -0300 Subject: [PATCH] Add class map in emotion config --- TTS/bin/train_encoder.py | 14 ++++++++++---- TTS/encoder/dataset.py | 3 +++ TTS/encoder/emotion_encoder_config.py | 1 + 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index e4ab6a1d..014c7cb2 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -18,7 +18,7 @@ from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.io import load_fsspec +from TTS.utils.io import load_fsspec, copy_model_files from TTS.utils.radam import RAdam from TTS.utils.training import check_update @@ -56,8 +56,9 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False shuffle=False, num_workers=c.num_loader_workers, collate_fn=dataset.collate_fn, - ) - return loader, dataset.get_num_classes() + ) + + return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname() def train(model, optimizer, scheduler, criterion, data_loader, global_step): @@ -160,7 +161,7 @@ def main(args): # pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False) - data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True) + data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) if c.loss == "ge2e": criterion = GE2ELoss(loss_method="softmax") @@ -168,6 +169,11 @@ def main(args): # pylint: disable=redefined-outer-name criterion = AngleProtoLoss() elif c.loss == "softmaxproto": criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes) + if c.model == "emotion_encoder": + # update config with the class map + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH) + print(OUT_PATH) else: raise Exception("The %s not is a loss supported" % c.loss) diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 515a2128..2c777a6a 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -102,6 +102,9 @@ class EncoderDataset(Dataset): def get_num_classes(self): return len(self.classes) + def get_map_classid_to_classname(self): + return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) + def __sample_class(self, ignore_classes=None): class_name = random.sample(self.classes, 1)[0] # if list of classes_id is provide make sure that it's will be ignored diff --git a/TTS/encoder/emotion_encoder_config.py b/TTS/encoder/emotion_encoder_config.py index 2957c6c9..87e90d1b 100644 --- a/TTS/encoder/emotion_encoder_config.py +++ b/TTS/encoder/emotion_encoder_config.py @@ -8,6 +8,7 @@ class EmotionEncoderConfig(SpeakerEncoderConfig): """Defines parameters for Speaker Encoder model.""" model: str = "emotion_encoder" + map_classid_to_classname: dict = None def check_values(self): super().check_values()