mirror of https://github.com/coqui-ai/TTS.git
Fix the hard-coded class count in plot_embeddings (use c.num_classes_in_batch instead of 10)
This commit is contained in:
parent
0a06d1e67b
commit
1c1684bdc5
|
@ -121,8 +121,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
|||
}
|
||||
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
||||
figures = {
|
||||
# FIXME: not constant
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||
}
|
||||
dashboard_logger.train_figures(global_step, figures)
|
||||
|
||||
|
|
|
@ -29,14 +29,18 @@ colormap = (
|
|||
)
|
||||
|
||||
|
||||
def plot_embeddings(embeddings, num_utter_per_class):
|
||||
embeddings = embeddings[: 10 * num_utter_per_class]
|
||||
def plot_embeddings(embeddings, num_classes_in_batch):
|
||||
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
||||
|
||||
# If necessary, keep only the first 10 classes
|
||||
if num_classes_in_batch > 10:
|
||||
num_classes_in_batch = 10
|
||||
embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
|
||||
|
||||
model = umap.UMAP()
|
||||
projection = model.fit_transform(embeddings)
|
||||
num_speakers = embeddings.shape[0] // num_utter_per_class
|
||||
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
|
||||
ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
|
||||
colors = [colormap[i] for i in ground_truth]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
||||
plt.gca().set_aspect("equal", "datalim")
|
||||
|
|
Loading…
Reference in New Issue