Fix the unit tests

Edresson Casanova 2022-03-10 16:22:33 -03:00
parent 50305215b3
commit 9c8b8201c3
6 changed files with 31 additions and 24 deletions

View File

@@ -42,29 +42,31 @@ c_dataset = load_config(args.config_dataset_path)
 meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
 wav_files = meta_data_train + meta_data_eval
 
-speaker_manager = SpeakerManager(
+encoder_manager = SpeakerManager(
     encoder_model_path=args.model_path,
     encoder_config_path=args.config_path,
     d_vectors_file_path=args.old_file,
     use_cuda=args.use_cuda,
 )
 
+class_name_key = encoder_manager.speaker_encoder_config.class_name_key
+
 # compute speaker embeddings
 speaker_mapping = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
-    if isinstance(wav_file, list):
-        class_name = wav_file[2]
-        wav_file = wav_file[1]
+    if isinstance(wav_file, dict):
+        class_name = wav_file[class_name_key]
+        wav_file = wav_file["audio_file"]
     else:
         class_name = None
 
     wav_file_name = os.path.basename(wav_file)
-    if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
+    if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
         # get the embedding from the old file
-        embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
+        embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
     else:
         # extract the embedding
-        embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+        embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
 
     # create speaker_mapping if target dataset is defined
     speaker_mapping[wav_file_name] = {}
@@ -81,5 +83,5 @@ if speaker_mapping:
     os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
 
     # pylint: disable=W0212
-    speaker_manager._save_json(mapping_file_path, speaker_mapping)
+    encoder_manager._save_json(mapping_file_path, speaker_mapping)
     print("Speaker embeddings saved at:", mapping_file_path)

View File

@@ -38,33 +38,32 @@ c_dataset = load_config(args.config_dataset_path)
 meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
 wav_files = meta_data_train + meta_data_eval
 
-speaker_manager = SpeakerManager(
+encoder_manager = SpeakerManager(
     encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
 )
 
-if speaker_manager.speaker_encoder_config.map_classid_to_classname is not None:
-    map_classid_to_classname = speaker_manager.speaker_encoder_config.map_classid_to_classname
-else:
-    map_classid_to_classname = None
+class_name_key = encoder_manager.speaker_encoder_config.class_name_key
+map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
 
 # compute speaker embeddings
 class_acc_dict = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
-    if isinstance(wav_file, list):
-        class_name = wav_file[2]
-        wav_file = wav_file[1]
+    if isinstance(wav_file, dict):
+        class_name = wav_file[class_name_key]
+        wav_file = wav_file["audio_file"]
     else:
         class_name = None
 
     # extract the embedding
-    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
-    if speaker_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
+    embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
+    if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
         embedding = torch.FloatTensor(embedd).unsqueeze(0)
         if args.use_cuda:
             embedding = embedding.cuda()
 
-        class_id = speaker_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
+        class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
         predicted_label = map_classid_to_classname[str(class_id)]
     else:
         predicted_label = None
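The getattr form above does more than shorten the old four-line if/else: it also covers configs that never declare map_classid_to_classname at all, where plain attribute access would raise. A minimal sketch of the difference (the Cfg classes are hypothetical stand-ins):

    class CfgWith:      # e.g. an emotion encoder config
        map_classid_to_classname = {"0": "happy"}

    class CfgWithout:   # e.g. a config that never declares the field
        pass

    getattr(CfgWith(), "map_classid_to_classname", None)     # {"0": "happy"}
    getattr(CfgWithout(), "map_classid_to_classname", None)  # None
    # CfgWithout().map_classid_to_classname would raise AttributeError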

View File

@@ -9,3 +9,4 @@ class EmotionEncoderConfig(BaseEncoderConfig):
 
     model: str = "emotion_encoder"
     map_classid_to_classname: dict = None
+    class_name_key: str = "emotion_name"

View File

@@ -8,3 +8,4 @@ class SpeakerEncoderConfig(BaseEncoderConfig):
     """Defines parameters for Speaker Encoder model."""
 
     model: str = "speaker_encoder"
+    class_name_key: str = "speaker_name"

View File

@@ -66,7 +66,7 @@ class EncoderDataset(Dataset):
         class_to_utters = {}
         for item in self.items:
             path_ = item["audio_file"]
-            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
+            class_name = item[self.config.class_name_key]
             if class_name in class_to_utters.keys():
                 class_to_utters[class_name].append(path_)
             else:
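With class_name_key declared on each config (see the two config diffs above), the dataset no longer branches on the model name; each config simply names the item field to group by. A minimal sketch, using stand-ins for the two configs in this commit:

    class EmotionEncoderConfig:       # stand-in: class_name_key = "emotion_name"
        class_name_key = "emotion_name"

    class SpeakerEncoderConfig:       # stand-in: class_name_key = "speaker_name"
        class_name_key = "speaker_name"

    item = {"audio_file": "clip.wav", "emotion_name": "happy", "speaker_name": "s1"}
    item[EmotionEncoderConfig().class_name_key]  # "happy"
    item[SpeakerEncoderConfig().class_name_key]  # "s1"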

View File

@@ -30,7 +30,7 @@ class BaseEncoder(nn.Module):
     def __init__(self):
         super(BaseEncoder, self).__init__()
 
-    def get_torch_mel_spectrogram_class(audio_config):
+    def get_torch_mel_spectrogram_class(self, audio_config):
         return torch.nn.Sequential(
             PreEmphasis(audio_config["preemphasis"]),
             # TorchSTFT(
@@ -59,7 +59,7 @@ class BaseEncoder(nn.Module):
         )
 
     @torch.no_grad()
-    def inference(self, x, l2_norm=False):
+    def inference(self, x, l2_norm=True):
         return self.forward(x, l2_norm)
 
     @torch.no_grad()
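Flipping the inference default to l2_norm=True means embeddings come back unit-normalized unless a caller opts out, so cosine similarity between two embeddings reduces to a plain dot product. A small sketch of that property (not code from this repo):

    import torch

    # what an l2_norm=True embedding looks like: unit length along the feature dim
    e1 = torch.nn.functional.normalize(torch.randn(1, 256), p=2, dim=1)
    e2 = torch.nn.functional.normalize(torch.randn(1, 256), p=2, dim=1)
    cosine = (e1 * e2).sum()  # equals torch.nn.functional.cosine_similarity(e1, e2) for unit vectors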
@@ -121,9 +121,13 @@ class BaseEncoder(nn.Module):
 
         # load the criterion for restore_path
         if criterion is not None and "criterion" in state:
-            criterion.load_state_dict(state["criterion"])
+            try:
+                criterion.load_state_dict(state["criterion"])
+            except (KeyError, RuntimeError) as error:
+                print(" > Criterion load ignored because of:", error)
 
         # instance and load the criterion for the encoder classifier in inference time
-        if eval and criterion is None and "criterion" in state and config.map_classid_to_classname is not None:
+        if eval and criterion is None and "criterion" in state and getattr(config, "map_classid_to_classname", None) is not None:
             criterion = self.get_criterion(config, len(config.map_classid_to_classname))
             criterion.load_state_dict(state["criterion"])
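The try/except added around criterion.load_state_dict lets a checkpoint whose criterion head does not match the current setup (for instance, trained with a different number of classes) be skipped with a warning instead of aborting. A minimal sketch of the failure it absorbs (the Linear layers are hypothetical stand-ins for the classifier head):

    import torch

    criterion = torch.nn.Linear(256, 10)                          # current head: 10 classes
    state = {"criterion": torch.nn.Linear(256, 5).state_dict()}   # checkpoint head: 5 classes
    try:
        criterion.load_state_dict(state["criterion"])
    except (KeyError, RuntimeError) as error:
        print(" > Criterion load ignored because of:", error)     # size mismatch -> RuntimeError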