mirror of https://github.com/coqui-ai/TTS.git
Fix the bug in plot_embeddings
This commit is contained in:
parent
0a06d1e67b
commit
1c1684bdc5
|
@ -121,8 +121,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
||||||
}
|
}
|
||||||
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
||||||
figures = {
|
figures = {
|
||||||
# FIXME: not constant
|
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
|
|
||||||
}
|
}
|
||||||
dashboard_logger.train_figures(global_step, figures)
|
dashboard_logger.train_figures(global_step, figures)
|
||||||
|
|
||||||
|
|
|
@ -29,14 +29,18 @@ colormap = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def plot_embeddings(embeddings, num_utter_per_class):
|
def plot_embeddings(embeddings, num_classes_in_batch):
|
||||||
embeddings = embeddings[: 10 * num_utter_per_class]
|
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
||||||
|
|
||||||
|
# if necessary get just the first 10 classes
|
||||||
|
if num_classes_in_batch > 10:
|
||||||
|
num_classes_in_batch = 10
|
||||||
|
embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
|
||||||
|
|
||||||
model = umap.UMAP()
|
model = umap.UMAP()
|
||||||
projection = model.fit_transform(embeddings)
|
projection = model.fit_transform(embeddings)
|
||||||
num_speakers = embeddings.shape[0] // num_utter_per_class
|
ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
|
||||||
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
|
|
||||||
colors = [colormap[i] for i in ground_truth]
|
colors = [colormap[i] for i in ground_truth]
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(16, 10))
|
fig, ax = plt.subplots(figsize=(16, 10))
|
||||||
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
||||||
plt.gca().set_aspect("equal", "datalim")
|
plt.gca().set_aspect("equal", "datalim")
|
||||||
|
|
Loading…
Reference in New Issue