mirror of https://github.com/coqui-ai/TTS.git
build: move umap-learn into optional notebook dependencies
Except for notebooks, it's only used to show embedding plots during speaker encoder training, in which case a warning is now shown to install it.
This commit is contained in:
parent
ff2cd5c97d
commit
59ef28d708
|
@ -6,6 +6,7 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -116,11 +117,14 @@ def evaluation(model, criterion, data_loader, global_step):
|
|||
eval_avg_loss = eval_loss / len(data_loader)
|
||||
# save stats
|
||||
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
|
||||
# plot the last batch in the evaluation
|
||||
figures = {
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||
}
|
||||
dashboard_logger.eval_figures(global_step, figures)
|
||||
try:
|
||||
# plot the last batch in the evaluation
|
||||
figures = {
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||
}
|
||||
dashboard_logger.eval_figures(global_step, figures)
|
||||
except ImportError:
|
||||
warnings.warn("Install the `umap-learn` package to see embedding plots.")
|
||||
return eval_avg_loss
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import umap
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
@ -30,6 +29,10 @@ colormap = (
|
|||
|
||||
|
||||
def plot_embeddings(embeddings, num_classes_in_batch):
|
||||
try:
|
||||
import umap
|
||||
except ImportError as e:
|
||||
raise ImportError("Package not installed: umap-learn") from e
|
||||
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
||||
|
||||
# if necessary get just the first 10 classes
|
||||
|
|
|
@ -58,8 +58,6 @@ dependencies = [
|
|||
"packaging>=23.1",
|
||||
# Inference
|
||||
"pysbd>=0.3.4",
|
||||
# Notebooks
|
||||
"umap-learn>=0.5.1",
|
||||
# Training
|
||||
"matplotlib>=3.7.0",
|
||||
# Coqui stack
|
||||
|
@ -100,6 +98,7 @@ docs = [
|
|||
notebooks = [
|
||||
"bokeh==1.4.0",
|
||||
"pandas>=1.4,<2.0",
|
||||
"umap-learn>=0.5.1",
|
||||
]
|
||||
# For running the TTS server
|
||||
server = ["flask>=2.0.1"]
|
||||
|
|
Loading…
Reference in New Issue