mirror of https://github.com/coqui-ai/TTS.git
build: move umap-learn into optional notebook dependencies
Except for notebooks, it's only used to show embedding plots during speaker encoder training, in which case a warning is now shown to install it.
This commit is contained in:
parent
ff2cd5c97d
commit
59ef28d708
|
@ -6,6 +6,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -116,11 +117,14 @@ def evaluation(model, criterion, data_loader, global_step):
|
||||||
eval_avg_loss = eval_loss / len(data_loader)
|
eval_avg_loss = eval_loss / len(data_loader)
|
||||||
# save stats
|
# save stats
|
||||||
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
|
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
|
||||||
# plot the last batch in the evaluation
|
try:
|
||||||
figures = {
|
# plot the last batch in the evaluation
|
||||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
figures = {
|
||||||
}
|
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||||
dashboard_logger.eval_figures(global_step, figures)
|
}
|
||||||
|
dashboard_logger.eval_figures(global_step, figures)
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn("Install the `umap-learn` package to see embedding plots.")
|
||||||
return eval_avg_loss
|
return eval_avg_loss
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import umap
|
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
@ -30,6 +29,10 @@ colormap = (
|
||||||
|
|
||||||
|
|
||||||
def plot_embeddings(embeddings, num_classes_in_batch):
|
def plot_embeddings(embeddings, num_classes_in_batch):
|
||||||
|
try:
|
||||||
|
import umap
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Package not installed: umap-learn") from e
|
||||||
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
||||||
|
|
||||||
# if necessary get just the first 10 classes
|
# if necessary get just the first 10 classes
|
||||||
|
|
|
@ -58,8 +58,6 @@ dependencies = [
|
||||||
"packaging>=23.1",
|
"packaging>=23.1",
|
||||||
# Inference
|
# Inference
|
||||||
"pysbd>=0.3.4",
|
"pysbd>=0.3.4",
|
||||||
# Notebooks
|
|
||||||
"umap-learn>=0.5.1",
|
|
||||||
# Training
|
# Training
|
||||||
"matplotlib>=3.7.0",
|
"matplotlib>=3.7.0",
|
||||||
# Coqui stack
|
# Coqui stack
|
||||||
|
@ -100,6 +98,7 @@ docs = [
|
||||||
notebooks = [
|
notebooks = [
|
||||||
"bokeh==1.4.0",
|
"bokeh==1.4.0",
|
||||||
"pandas>=1.4,<2.0",
|
"pandas>=1.4,<2.0",
|
||||||
|
"umap-learn>=0.5.1",
|
||||||
]
|
]
|
||||||
# For running the TTS server
|
# For running the TTS server
|
||||||
server = ["flask>=2.0.1"]
|
server = ["flask>=2.0.1"]
|
||||||
|
|
Loading…
Reference in New Issue