mirror of https://github.com/coqui-ai/TTS.git
Add class map in emotion config
This commit is contained in:
parent
854c887764
commit
33ac13e44e
|
@ -18,7 +18,7 @@ from TTS.encoder.utils.visual import plot_embeddings
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec, copy_model_files
|
||||||
from TTS.utils.radam import RAdam
|
from TTS.utils.radam import RAdam
|
||||||
from TTS.utils.training import check_update
|
from TTS.utils.training import check_update
|
||||||
|
|
||||||
|
@ -56,8 +56,9 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=c.num_loader_workers,
|
num_workers=c.num_loader_workers,
|
||||||
collate_fn=dataset.collate_fn,
|
collate_fn=dataset.collate_fn,
|
||||||
)
|
)
|
||||||
return loader, dataset.get_num_classes()
|
|
||||||
|
return loader, dataset.get_num_classes(), dataset.get_map_classid_to_classname()
|
||||||
|
|
||||||
|
|
||||||
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
||||||
|
@ -160,7 +161,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
|
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
|
||||||
|
|
||||||
data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True)
|
data_loader, num_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
|
||||||
|
|
||||||
if c.loss == "ge2e":
|
if c.loss == "ge2e":
|
||||||
criterion = GE2ELoss(loss_method="softmax")
|
criterion = GE2ELoss(loss_method="softmax")
|
||||||
|
@ -168,6 +169,11 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
criterion = AngleProtoLoss()
|
criterion = AngleProtoLoss()
|
||||||
elif c.loss == "softmaxproto":
|
elif c.loss == "softmaxproto":
|
||||||
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
|
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
|
||||||
|
if c.model == "emotion_encoder":
|
||||||
|
# update config with the class map
|
||||||
|
c.map_classid_to_classname = map_classid_to_classname
|
||||||
|
copy_model_files(c, OUT_PATH)
|
||||||
|
print(OUT_PATH)
|
||||||
else:
|
else:
|
||||||
raise Exception("The %s not is a loss supported" % c.loss)
|
raise Exception("The %s not is a loss supported" % c.loss)
|
||||||
|
|
||||||
|
|
|
@ -102,6 +102,9 @@ class EncoderDataset(Dataset):
|
||||||
def get_num_classes(self):
|
def get_num_classes(self):
|
||||||
return len(self.classes)
|
return len(self.classes)
|
||||||
|
|
||||||
|
def get_map_classid_to_classname(self):
|
||||||
|
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
|
||||||
|
|
||||||
def __sample_class(self, ignore_classes=None):
|
def __sample_class(self, ignore_classes=None):
|
||||||
class_name = random.sample(self.classes, 1)[0]
|
class_name = random.sample(self.classes, 1)[0]
|
||||||
# if list of classes_id is provide make sure that it's will be ignored
|
# if list of classes_id is provide make sure that it's will be ignored
|
||||||
|
|
|
@ -8,6 +8,7 @@ class EmotionEncoderConfig(SpeakerEncoderConfig):
|
||||||
"""Defines parameters for Speaker Encoder model."""
|
"""Defines parameters for Speaker Encoder model."""
|
||||||
|
|
||||||
model: str = "emotion_encoder"
|
model: str = "emotion_encoder"
|
||||||
|
map_classid_to_classname: dict = None
|
||||||
|
|
||||||
def check_values(self):
|
def check_values(self):
|
||||||
super().check_values()
|
super().check_values()
|
||||||
|
|
Loading…
Reference in New Issue