Fix the bug in plot_embeddings

Edresson Casanova 2022-03-02 18:39:20 -03:00
parent 0a06d1e67b
commit 1c1684bdc5
2 changed files with 10 additions and 7 deletions


@@ -121,8 +121,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
             }
             dashboard_logger.train_epoch_stats(global_step, train_stats)
             figures = {
-                # FIXME: not constant
-                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
+                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
             }
             dashboard_logger.train_figures(global_step, figures)
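
A minimal numeric sketch (not part of the commit) of why the hard-coded 10 broke the plot: it reuses the label computation from plot_embeddings below, and the 5-classes-by-4-utterances batch shape is made up for illustration.

    import numpy as np

    num_classes_in_batch = 5   # classes actually drawn by the batch sampler (assumed)
    num_utter_per_class = 4    # utterances per class (assumed)
    num_embeddings = num_classes_in_batch * num_utter_per_class  # 20 embeddings in the batch

    # Old call: the constant 10 was treated as utterances per class,
    # so only 20 // 10 = 2 ground-truth groups were colored.
    old_labels = np.repeat(np.arange(num_embeddings // 10), 10)
    print(old_labels)  # [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]

    # New call: the class count comes from the config, so each of the
    # 5 classes gets its own color with 4 points per class.
    new_labels = np.repeat(np.arange(num_classes_in_batch), num_embeddings // num_classes_in_batch)
    print(new_labels)  # [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4]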


@@ -29,14 +29,18 @@ colormap = (
 )
-def plot_embeddings(embeddings, num_utter_per_class):
-    embeddings = embeddings[: 10 * num_utter_per_class]
+def plot_embeddings(embeddings, num_classes_in_batch):
+    num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
+    # if necessary get just the first 10 classes
+    if num_classes_in_batch > 10:
+        num_classes_in_batch = 10
+        embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
     model = umap.UMAP()
     projection = model.fit_transform(embeddings)
-    num_speakers = embeddings.shape[0] // num_utter_per_class
-    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
+    ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
     colors = [colormap[i] for i in ground_truth]
     fig, ax = plt.subplots(figsize=(16, 10))
     _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
     plt.gca().set_aspect("equal", "datalim")
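
For reference, a minimal way to exercise the patched function, assuming plot_embeddings (together with its module-level umap, numpy, matplotlib and colormap setup) is importable from the edited file; the batch shape and embedding dimension below are made-up values.

    import numpy as np

    # Dummy batch: 5 classes x 4 utterances each, 256-dim embeddings (values are random).
    num_classes_in_batch, num_utter_per_class, dim = 5, 4, 256
    dummy = np.random.rand(num_classes_in_batch * num_utter_per_class, dim).astype(np.float32)

    # The second argument is now the number of classes in the batch, not a constant:
    # the UMAP scatter plot gets 5 color groups with 4 points each.
    plot_embeddings(dummy, num_classes_in_batch)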