Add class map in emotion config

This commit is contained in:
Edresson Casanova 2022-03-01 10:34:05 -03:00
parent 854c887764
commit 33ac13e44e
3 changed files with 14 additions and 4 deletions

View File

@ -18,7 +18,7 @@ from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.io import load_fsspec, copy_model_files
from TTS.utils.radam import RAdam
from TTS.utils.training import check_update
@ -56,8 +56,9 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn,
)
return loader, dataset.get_num_classes()
)
return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname()
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
@ -160,7 +161,7 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True)
data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
@ -168,6 +169,11 @@ def main(args): # pylint: disable=redefined-outer-name
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
if c.model == "emotion_encoder":
# update config with the class map
c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH)
print(OUT_PATH)
else:
raise Exception("The %s not is a loss supported" % c.loss)

View File

@ -102,6 +102,9 @@ class EncoderDataset(Dataset):
def get_num_classes(self):
return len(self.classes)
def get_map_classid_to_classname(self):
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
def __sample_class(self, ignore_classes=None):
class_name = random.sample(self.classes, 1)[0]
# if list of classes_id is provide make sure that it's will be ignored

View File

@ -8,6 +8,7 @@ class EmotionEncoderConfig(SpeakerEncoderConfig):
"""Defines parameters for Speaker Encoder model."""
model: str = "emotion_encoder"
map_classid_to_classname: dict = None
def check_values(self):
super().check_values()