diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 50817154..917994a2 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -53,10 +53,10 @@ speaker_manager = SpeakerManager(
 speaker_mapping = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
     if isinstance(wav_file, list):
-        speaker_name = wav_file[2]
+        class_name = wav_file[2]
         wav_file = wav_file[1]
     else:
-        speaker_name = None
+        class_name = None
 
     wav_file_name = os.path.basename(wav_file)
     if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
@@ -68,7 +68,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
 
     # create speaker_mapping if target dataset is defined
     speaker_mapping[wav_file_name] = {}
-    speaker_mapping[wav_file_name]["name"] = speaker_name
+    speaker_mapping[wav_file_name]["name"] = class_name
     speaker_mapping[wav_file_name]["embedding"] = embedd
 
 if speaker_mapping:
diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py
new file mode 100644
index 00000000..8acc8ffc
--- /dev/null
+++ b/TTS/bin/eval_encoder.py
@@ -0,0 +1,89 @@
+import argparse
+import sys
+import torch
+
+from argparse import RawTextHelpFormatter
+from tqdm import tqdm
+
+from TTS.config import load_config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.speakers import SpeakerManager
+
+parser = argparse.ArgumentParser(
+    description="""Compute the accuracy of the encoder.\n\n"""
+    """
+    Example runs:
+    python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
+    """,
+    formatter_class=RawTextHelpFormatter,
+)
+parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
+parser.add_argument(
+    "config_path",
+    type=str,
+    help="Path to model config file.",
+)
+
+parser.add_argument(
+    "config_dataset_path",
+    type=str,
+    help="Path to dataset config file.",
+)
+parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
+parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+
+args = parser.parse_args()
+
+c_dataset = load_config(args.config_dataset_path)
+
+meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
+wav_files = meta_data_train + meta_data_eval
+
+speaker_manager = SpeakerManager(
+    encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
+)
+
+if speaker_manager.speaker_encoder_config.map_classid_to_classname is not None:
+    map_classid_to_classname = speaker_manager.speaker_encoder_config.map_classid_to_classname
+else:
+    map_classid_to_classname = None
+
+# accumulate per-class accuracy flags
+class_acc_dict = {}
+
+for idx, wav_file in enumerate(tqdm(wav_files)):
+    if isinstance(wav_file, list):
+        class_name = wav_file[2]
+        wav_file = wav_file[1]
+    else:
+        class_name = None
+
+    # extract the embedding
+    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+    if speaker_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
+        embedding = torch.FloatTensor(embedd).unsqueeze(0)
+        if args.use_cuda:
+            embedding = embedding.cuda()
+
+        class_id = speaker_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
+        predicted_label = map_classid_to_classname[str(class_id)]
+    else:
+        predicted_label = None
+
+    if class_name is not None and predicted_label is not None:
+        is_equal = int(class_name == predicted_label)
+        if class_name not in class_acc_dict:
+            class_acc_dict[class_name] = [is_equal]
+        else:
+            class_acc_dict[class_name].append(is_equal)
+    else:
+        print("Error: class_name and/or predicted_label are None")
+        sys.exit(1)
+
+acc_avg = 0
+for key in class_acc_dict:
+    acc = sum(class_acc_dict[key]) / len(class_acc_dict[key])
+    print("Class", key, "ACC:", acc)
+    acc_avg += acc
+
+print("Average Acc:", acc_avg / len(class_acc_dict))
diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py
index 8ba917b7..de65d8d6 100644
--- a/TTS/encoder/losses.py
+++ b/TTS/encoder/losses.py
@@ -189,6 +189,11 @@ class SoftmaxLoss(nn.Module):
 
         return L
 
+    def inference(self, embedding):
+        x = self.fc(embedding)  # logits, shape: (1, n_classes)
+        activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)  # class probabilities, shape: (n_classes,)
+        class_id = torch.argmax(activations)
+        return class_id
 
 class SoftmaxAngleProtoLoss(nn.Module):
     """
diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py
index dfba53cc..6144a9b4 100644
--- a/TTS/encoder/models/lstm.py
+++ b/TTS/encoder/models/lstm.py
@@ -182,8 +182,20 @@ class LSTMSpeakerEncoder(nn.Module):
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
+        # load the criterion for emotion classification
+        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
+            # local import: unlike resnet.py, this module does not import SoftmaxAngleProtoLoss at the top
+            from TTS.encoder.losses import SoftmaxAngleProtoLoss
+            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
+            criterion.load_state_dict(state["criterion"])
+        else:
+            criterion = None
+
         if use_cuda:
             self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
         if eval:
             self.eval()
             assert not self.training
+        return criterion
diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py
index a799fc52..65da2ea1 100644
--- a/TTS/encoder/models/resnet.py
+++ b/TTS/encoder/models/resnet.py
@@ -5,7 +5,7 @@ from torch import nn
 
 # from TTS.utils.audio import TorchSTFT
 from TTS.utils.io import load_fsspec
-
+from TTS.encoder.losses import SoftmaxAngleProtoLoss
 
 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -277,8 +277,18 @@ class ResNetSpeakerEncoder(nn.Module):
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
+        # load the criterion for emotion classification
+        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
+            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
+            criterion.load_state_dict(state["criterion"])
+        else:
+            criterion = None
+
         if use_cuda:
             self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
         if eval:
             self.eval()
             assert not self.training
+        return criterion
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index c2da7eb5..1a5da94a 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -269,7 +269,7 @@ class SpeakerManager:
         """
         self.speaker_encoder_config = load_config(config_path)
         self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
-        self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
+        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
         self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
 
     def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
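
Note: a minimal usage sketch of the criterion plumbing added above, mirroring the flow in TTS/bin/eval_encoder.py. It assumes a trained emotion encoder; the checkpoint, config, and wav paths are illustrative.

    import torch

    from TTS.tts.utils.speakers import SpeakerManager

    # loading the encoder also loads the classification criterion from the
    # checkpoint (via the new load_checkpoint return value)
    manager = SpeakerManager(
        encoder_model_path="emotion_encoder_model.pth.tar",  # hypothetical checkpoint
        encoder_config_path="emotion_encoder_config.json",  # hypothetical config
        use_cuda=False,
    )

    # compute the d-vector for one clip, then classify it with the loaded criterion
    embedd = manager.compute_d_vector_from_clip("sample.wav")  # hypothetical clip
    embedding = torch.FloatTensor(embedd).unsqueeze(0)
    class_id = manager.speaker_encoder_criterion.softmax.inference(embedding).item()
    print(manager.speaker_encoder_config.map_classid_to_classname[str(class_id)])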