mirror of https://github.com/coqui-ai/TTS.git
Add evaluation encoder script
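Adds TTS/bin/eval_encoder.py, a script that loads a trained encoder together with its classification criterion, predicts a class for every sample listed in a dataset config, and reports per-class and average accuracy. To support it, SoftmaxLoss gains an inference() method, load_checkpoint() on the LSTM and ResNet encoders now restores the SoftmaxAngleProtoLoss criterion from emotion-encoder checkpoints and returns it (or None), SpeakerManager stores that return value as speaker_encoder_criterion, and the embedding-extraction loop uses the more general class_name in place of speaker_name.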
This commit is contained in:
parent
f811af7651
commit
0a06d1e67b
@@ -53,10 +53,10 @@ speaker_manager = SpeakerManager(
 speaker_mapping = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
     if isinstance(wav_file, list):
-        speaker_name = wav_file[2]
+        class_name = wav_file[2]
         wav_file = wav_file[1]
     else:
-        speaker_name = None
+        class_name = None
 
     wav_file_name = os.path.basename(wav_file)
     if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
@@ -68,7 +68,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
 
     # create speaker_mapping if target dataset is defined
     speaker_mapping[wav_file_name] = {}
-    speaker_mapping[wav_file_name]["name"] = speaker_name
+    speaker_mapping[wav_file_name]["name"] = class_name
     speaker_mapping[wav_file_name]["embedding"] = embedd
 
 if speaker_mapping:
@@ -0,0 +1,89 @@
+import argparse
+import os
+import torch
+from argparse import RawTextHelpFormatter
+
+from tqdm import tqdm
+
+from TTS.config import load_config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.speakers import SpeakerManager
+
+parser = argparse.ArgumentParser(
+    description="""Compute the accuracy of the encoder.\n\n"""
+    """
+    Example runs:
+    python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
+    """,
+    formatter_class=RawTextHelpFormatter,
+)
+parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
+parser.add_argument(
+    "config_path",
+    type=str,
+    help="Path to model config file.",
+)
+
+parser.add_argument(
+    "config_dataset_path",
+    type=str,
+    help="Path to dataset config file.",
+)
+parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
+parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+
+args = parser.parse_args()
+
+c_dataset = load_config(args.config_dataset_path)
+
+meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
+wav_files = meta_data_train + meta_data_eval
+
+speaker_manager = SpeakerManager(
+    encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
+)
+
+if speaker_manager.speaker_encoder_config.map_classid_to_classname is not None:
+    map_classid_to_classname = speaker_manager.speaker_encoder_config.map_classid_to_classname
+else:
+    map_classid_to_classname = None
+
+# compute speaker embeddings
+class_acc_dict = {}
+
+for idx, wav_file in enumerate(tqdm(wav_files)):
+    if isinstance(wav_file, list):
+        class_name = wav_file[2]
+        wav_file = wav_file[1]
+    else:
+        class_name = None
+
+    # extract the embedding
+    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+    if speaker_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
+        embedding = torch.FloatTensor(embedd).unsqueeze(0)
+        if args.use_cuda:
+            embedding = embedding.cuda()
+
+        class_id = speaker_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
+        predicted_label = map_classid_to_classname[str(class_id)]
+    else:
+        predicted_label = None
+
+    if class_name is not None and predicted_label is not None:
+        is_equal = int(class_name == predicted_label)
+        if class_name not in class_acc_dict:
+            class_acc_dict[class_name] = [is_equal]
+        else:
+            class_acc_dict[class_name].append(is_equal)
+    else:
+        print("Error: class_name or/and predicted_label are None")
+        exit()
+
+acc_avg = 0
+for key in class_acc_dict:
+    acc = sum(class_acc_dict[key])/len(class_acc_dict[key])
+    print("Class", key, "ACC:", acc)
+    acc_avg += acc
+
+print("Average Acc:", acc_avg/len(class_acc_dict))
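As the docstring shows, the script is run as python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json. It evaluates the train and eval splits together, prints one "Class <name> ACC:" line per class followed by the unweighted average over classes, and calls exit() as soon as a sample has no ground-truth class or the checkpoint carries no classification criterion, so it expects an emotion-encoder checkpoint with map_classid_to_classname set in its config.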
@@ -189,6 +189,11 @@ class SoftmaxLoss(nn.Module):
 
         return L
 
+    def inference(self, embedding):
+        x = self.fc(embedding)
+        activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)
+        class_id = torch.argmax(activations)
+        return class_id
 
 class SoftmaxAngleProtoLoss(nn.Module):
     """
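A minimal sketch of the new method in isolation, assuming SoftmaxLoss is importable from TTS.encoder.losses alongside SoftmaxAngleProtoLoss and takes (embedding dim, number of classes) positionally, as the SoftmaxAngleProtoLoss constructor call in the load_checkpoint hunks below suggests:

    import torch

    from TTS.encoder.losses import SoftmaxLoss

    # Hypothetical sizes: 512-dim d-vectors classified into 5 classes.
    loss_layer = SoftmaxLoss(512, 5)
    embedding = torch.randn(1, 512)             # one embedding with a batch dimension
    class_id = loss_layer.inference(embedding)  # argmax over softmax(fc(embedding))
    print(int(class_id))                        # an index in [0, 4]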
@@ -182,8 +182,18 @@ class LSTMSpeakerEncoder(nn.Module):
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
+        # load the criterion for emotion classification
+        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
+            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
+            criterion.load_state_dict(state["criterion"])
+        else:
+            criterion = None
+
         if use_cuda:
             self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
         if eval:
             self.eval()
             assert not self.training
+        return criterion
@@ -5,7 +5,7 @@ from torch import nn
 
 # from TTS.utils.audio import TorchSTFT
 from TTS.utils.io import load_fsspec
 
+from TTS.encoder.losses import SoftmaxAngleProtoLoss
 
 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -277,8 +277,18 @@ class ResNetSpeakerEncoder(nn.Module):
     def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
+        # load the criterion for emotion classification
+        if "criterion" in state and config.loss == "softmaxproto" and config.model == "emotion_encoder" and config.map_classid_to_classname is not None:
+            criterion = SoftmaxAngleProtoLoss(config.model_params["proj_dim"], len(config.map_classid_to_classname.keys()))
+            criterion.load_state_dict(state["criterion"])
+        else:
+            criterion = None
+
         if use_cuda:
             self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
         if eval:
             self.eval()
             assert not self.training
+        return criterion
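The same criterion-loading block is added to both LSTMSpeakerEncoder and ResNetSpeakerEncoder: load_checkpoint() now returns the restored SoftmaxAngleProtoLoss when the checkpoint stores a criterion and the config describes an emotion encoder trained with softmaxproto loss, and None otherwise, so existing callers that ignore the return value are unaffected.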
@@ -269,7 +269,7 @@ class SpeakerManager:
         """
         self.speaker_encoder_config = load_config(config_path)
         self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
-        self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
+        self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
         self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
 
     def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
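Taken together, a minimal sketch of how these pieces combine outside the new script, mirroring eval_encoder.py; the checkpoint, config, and wav paths are placeholders:

    import torch

    from TTS.tts.utils.speakers import SpeakerManager

    manager = SpeakerManager(
        encoder_model_path="emotion_encoder_model.pth.tar",  # placeholder checkpoint
        encoder_config_path="emotion_encoder_config.json",   # placeholder config
        use_cuda=False,
    )

    # d-vector for a single clip, then a (1, proj_dim) tensor for the criterion
    embedd = manager.compute_d_vector_from_clip("sample.wav")  # placeholder wav
    embedding = torch.FloatTensor(embedd).unsqueeze(0)

    criterion = manager.speaker_encoder_criterion  # returned by the new load_checkpoint()
    if criterion is not None:
        class_id = criterion.softmax.inference(embedding).item()
        label = manager.speaker_encoder_config.map_classid_to_classname[str(class_id)]
        print("Predicted label:", label)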