build: move umap-learn into optional notebook dependencies

Except for notebooks, it's only used to show embedding plots during speaker
encoder training, in which case a warning is now shown to install it.
This commit is contained in:
Enno Hermann 2024-06-26 23:53:17 +02:00
parent ff2cd5c97d
commit 59ef28d708
3 changed files with 14 additions and 8 deletions

View File

@ -6,6 +6,7 @@ import os
import sys import sys
import time import time
import traceback import traceback
import warnings
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -116,11 +117,14 @@ def evaluation(model, criterion, data_loader, global_step):
eval_avg_loss = eval_loss / len(data_loader) eval_avg_loss = eval_loss / len(data_loader)
# save stats # save stats
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss}) dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
# plot the last batch in the evaluation try:
figures = { # plot the last batch in the evaluation
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), figures = {
} "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
dashboard_logger.eval_figures(global_step, figures) }
dashboard_logger.eval_figures(global_step, figures)
except ImportError:
warnings.warn("Install the `umap-learn` package to see embedding plots.")
return eval_avg_loss return eval_avg_loss

View File

@ -1,7 +1,6 @@
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import umap
matplotlib.use("Agg") matplotlib.use("Agg")
@ -30,6 +29,10 @@ colormap = (
def plot_embeddings(embeddings, num_classes_in_batch): def plot_embeddings(embeddings, num_classes_in_batch):
try:
import umap
except ImportError as e:
raise ImportError("Package not installed: umap-learn") from e
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
# if necessary get just the first 10 classes # if necessary get just the first 10 classes

View File

@ -58,8 +58,6 @@ dependencies = [
"packaging>=23.1", "packaging>=23.1",
# Inference # Inference
"pysbd>=0.3.4", "pysbd>=0.3.4",
# Notebooks
"umap-learn>=0.5.1",
# Training # Training
"matplotlib>=3.7.0", "matplotlib>=3.7.0",
# Coqui stack # Coqui stack
@ -100,6 +98,7 @@ docs = [
notebooks = [ notebooks = [
"bokeh==1.4.0", "bokeh==1.4.0",
"pandas>=1.4,<2.0", "pandas>=1.4,<2.0",
"umap-learn>=0.5.1",
] ]
# For running the TTS server # For running the TTS server
server = ["flask>=2.0.1"] server = ["flask>=2.0.1"]