mirror of https://github.com/coqui-ai/TTS.git
Add class map in emotion config
This commit is contained in:
parent
854c887764
commit
33ac13e44e
|
@ -18,7 +18,7 @@ from TTS.encoder.utils.visual import plot_embeddings
|
|||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.io import load_fsspec, copy_model_files
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
|
@ -56,8 +56,9 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
|||
shuffle=False,
|
||||
num_workers=c.num_loader_workers,
|
||||
collate_fn=dataset.collate_fn,
|
||||
)
|
||||
return loader, dataset.get_num_classes()
|
||||
)
|
||||
|
||||
return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname()
|
||||
|
||||
|
||||
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
||||
|
@ -160,7 +161,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
|
||||
|
||||
data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True)
|
||||
data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
|
||||
|
||||
if c.loss == "ge2e":
|
||||
criterion = GE2ELoss(loss_method="softmax")
|
||||
|
@ -168,6 +169,11 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
criterion = AngleProtoLoss()
|
||||
elif c.loss == "softmaxproto":
|
||||
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
|
||||
if c.model == "emotion_encoder":
|
||||
# update config with the class map
|
||||
c.map_classid_to_classname = map_classid_to_classname
|
||||
copy_model_files(c, OUT_PATH)
|
||||
print(OUT_PATH)
|
||||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
|
||||
|
|
|
@ -102,6 +102,9 @@ class EncoderDataset(Dataset):
|
|||
def get_num_classes(self):
|
||||
return len(self.classes)
|
||||
|
||||
def get_map_classid_to_classname(self):
|
||||
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
|
||||
|
||||
def __sample_class(self, ignore_classes=None):
|
||||
class_name = random.sample(self.classes, 1)[0]
|
||||
# if list of classes_id is provide make sure that it's will be ignored
|
||||
|
|
|
@ -8,6 +8,7 @@ class EmotionEncoderConfig(SpeakerEncoderConfig):
|
|||
"""Defines parameters for Speaker Encoder model."""
|
||||
|
||||
model: str = "emotion_encoder"
|
||||
map_classid_to_classname: dict = None
|
||||
|
||||
def check_values(self):
|
||||
super().check_values()
|
||||
|
|
Loading…
Reference in New Issue