diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index ae03ab5d..08f81cc2 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -121,8 +121,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
             }
             dashboard_logger.train_epoch_stats(global_step, train_stats)
             figures = {
-                # FIXME: not constant
-                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
+                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
             }
 
             dashboard_logger.train_figures(global_step, figures)
diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py
index 65322258..f2db2f3f 100644
--- a/TTS/encoder/utils/visual.py
+++ b/TTS/encoder/utils/visual.py
@@ -29,14 +29,18 @@ colormap = (
 )
 
 
-def plot_embeddings(embeddings, num_utter_per_class):
-    embeddings = embeddings[: 10 * num_utter_per_class]
+def plot_embeddings(embeddings, num_classes_in_batch):
+    num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
+
+    # if necessary get just the first 10 classes
+    if num_classes_in_batch > 10:
+        num_classes_in_batch = 10
+        embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
+
     model = umap.UMAP()
     projection = model.fit_transform(embeddings)
-    num_speakers = embeddings.shape[0] // num_utter_per_class
-    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
+    ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
     colors = [colormap[i] for i in ground_truth]
-
     fig, ax = plt.subplots(figsize=(16, 10))
     _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
     plt.gca().set_aspect("equal", "datalim")