mirror of https://github.com/databricks/cli.git
Improved usability of `databricks auth login ... --configure-cluster` flow by displaying cluster type and runtime version (#956)
This PR adds selectors for Databricks-connect compatible clusters and SQL warehouses Tested in https://github.com/databricks/cli/pull/914
This commit is contained in:
parent
f111b0846e
commit
3284a8c56c
|
@ -8,9 +8,9 @@ import (
|
|||
"github.com/databricks/cli/libs/auth"
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/cli/libs/databrickscfg"
|
||||
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/databricks/databricks-sdk-go/service/compute"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
@ -28,6 +28,8 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
|
|||
return nil
|
||||
}
|
||||
|
||||
const minimalDbConnectVersion = "13.1"
|
||||
|
||||
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "login [HOST]",
|
||||
|
@ -95,19 +97,12 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
|||
return err
|
||||
}
|
||||
ctx := cmd.Context()
|
||||
|
||||
promptSpinner := cmdio.Spinner(ctx)
|
||||
promptSpinner <- "Loading list of clusters to select from"
|
||||
names, err := w.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{})
|
||||
close(promptSpinner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load clusters list. Original error: %w", err)
|
||||
}
|
||||
clusterId, err := cmdio.Select(ctx, names, "Choose cluster")
|
||||
clusterID, err := cfgpickers.AskForCluster(ctx, w,
|
||||
cfgpickers.WithDatabricksConnect(minimalDbConnectVersion))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.ClusterID = clusterId
|
||||
cfg.ClusterID = clusterID
|
||||
}
|
||||
|
||||
if profileName != "" {
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
package cfgpickers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/service/compute"
|
||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||
"github.com/fatih/color"
|
||||
"github.com/manifoldco/promptui"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
var minUcRuntime = canonicalVersion("v12.0")
|
||||
|
||||
var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`)
|
||||
var dbrSnapshotVersionRegex = regexp.MustCompile(`^(\d+)\.x-snapshot.*`)
|
||||
|
||||
func canonicalVersion(v string) string {
|
||||
return semver.Canonical("v" + strings.TrimPrefix(v, "v"))
|
||||
}
|
||||
|
||||
func GetRuntimeVersion(cluster compute.ClusterDetails) (string, bool) {
|
||||
match := dbrVersionRegex.FindStringSubmatch(cluster.SparkVersion)
|
||||
if len(match) < 1 {
|
||||
match = dbrSnapshotVersionRegex.FindStringSubmatch(cluster.SparkVersion)
|
||||
if len(match) > 1 {
|
||||
// we return 14.999 for 14.x-snapshot for semver.Compare() to work properly
|
||||
return fmt.Sprintf("%s.999", match[1]), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
return match[1], true
|
||||
}
|
||||
|
||||
func IsCompatibleWithUC(cluster compute.ClusterDetails, minVersion string) bool {
|
||||
minVersion = canonicalVersion(minVersion)
|
||||
if semver.Compare(minUcRuntime, minVersion) >= 0 {
|
||||
return false
|
||||
}
|
||||
runtimeVersion, ok := GetRuntimeVersion(cluster)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
clusterRuntime := canonicalVersion(runtimeVersion)
|
||||
if semver.Compare(minVersion, clusterRuntime) > 0 {
|
||||
return false
|
||||
}
|
||||
switch cluster.DataSecurityMode {
|
||||
case compute.DataSecurityModeUserIsolation, compute.DataSecurityModeSingleUser:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var ErrNoCompatibleClusters = errors.New("no compatible clusters found")
|
||||
|
||||
type compatibleCluster struct {
|
||||
compute.ClusterDetails
|
||||
versionName string
|
||||
}
|
||||
|
||||
func (v compatibleCluster) Access() string {
|
||||
switch v.DataSecurityMode {
|
||||
case compute.DataSecurityModeUserIsolation:
|
||||
return "Shared"
|
||||
case compute.DataSecurityModeSingleUser:
|
||||
return "Assigned"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (v compatibleCluster) Runtime() string {
|
||||
runtime, _, _ := strings.Cut(v.versionName, " (")
|
||||
return runtime
|
||||
}
|
||||
|
||||
func (v compatibleCluster) State() string {
|
||||
state := v.ClusterDetails.State
|
||||
switch state {
|
||||
case compute.StateRunning, compute.StateResizing:
|
||||
return color.GreenString(state.String())
|
||||
case compute.StateError, compute.StateTerminated, compute.StateTerminating, compute.StateUnknown:
|
||||
return color.RedString(state.String())
|
||||
default:
|
||||
return color.BlueString(state.String())
|
||||
}
|
||||
}
|
||||
|
||||
type clusterFilter func(cluster *compute.ClusterDetails, me *iam.User) bool
|
||||
|
||||
func WithDatabricksConnect(minVersion string) func(*compute.ClusterDetails, *iam.User) bool {
|
||||
return func(cluster *compute.ClusterDetails, me *iam.User) bool {
|
||||
if !IsCompatibleWithUC(*cluster, minVersion) {
|
||||
return false
|
||||
}
|
||||
switch cluster.ClusterSource {
|
||||
case compute.ClusterSourceJob,
|
||||
compute.ClusterSourceModels,
|
||||
compute.ClusterSourcePipeline,
|
||||
compute.ClusterSourcePipelineMaintenance,
|
||||
compute.ClusterSourceSql:
|
||||
// only UI and API clusters are usable for DBConnect.
|
||||
// `CanUseClient: "NOTEBOOKS"`` didn't seem to have an effect.
|
||||
return false
|
||||
}
|
||||
if cluster.SingleUserName != "" && cluster.SingleUserName != me.UserName {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func loadInteractiveClusters(ctx context.Context, w *databricks.WorkspaceClient, filters []clusterFilter) ([]compatibleCluster, error) {
|
||||
promptSpinner := cmdio.Spinner(ctx)
|
||||
promptSpinner <- "Loading list of clusters to select from"
|
||||
defer close(promptSpinner)
|
||||
all, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
|
||||
CanUseClient: "NOTEBOOKS",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list clusters: %w", err)
|
||||
}
|
||||
me, err := w.CurrentUser.Me(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("current user: %w", err)
|
||||
}
|
||||
versions := map[string]string{}
|
||||
sv, err := w.Clusters.SparkVersions(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list runtime versions: %w", err)
|
||||
}
|
||||
for _, v := range sv.Versions {
|
||||
versions[v.Key] = v.Name
|
||||
}
|
||||
var compatible []compatibleCluster
|
||||
for _, cluster := range all {
|
||||
var skip bool
|
||||
for _, filter := range filters {
|
||||
if !filter(&cluster, me) {
|
||||
skip = true
|
||||
}
|
||||
}
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
compatible = append(compatible, compatibleCluster{
|
||||
ClusterDetails: cluster,
|
||||
versionName: versions[cluster.SparkVersion],
|
||||
})
|
||||
}
|
||||
return compatible, nil
|
||||
}
|
||||
|
||||
func AskForCluster(ctx context.Context, w *databricks.WorkspaceClient, filters ...clusterFilter) (string, error) {
|
||||
compatible, err := loadInteractiveClusters(ctx, w, filters)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load: %w", err)
|
||||
}
|
||||
if len(compatible) == 0 {
|
||||
return "", ErrNoCompatibleClusters
|
||||
}
|
||||
if len(compatible) == 1 {
|
||||
return compatible[0].ClusterId, nil
|
||||
}
|
||||
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
|
||||
Label: "Choose compatible cluster",
|
||||
Items: compatible,
|
||||
Searcher: func(input string, idx int) bool {
|
||||
lower := strings.ToLower(compatible[idx].ClusterName)
|
||||
return strings.Contains(lower, input)
|
||||
},
|
||||
StartInSearchMode: true,
|
||||
Templates: &promptui.SelectTemplates{
|
||||
Label: "{{.ClusterName | faint}}",
|
||||
Active: `{{.ClusterName | bold}} ({{.State}} {{.Access}} Runtime {{.Runtime}}) ({{.ClusterId | faint}})`,
|
||||
Inactive: `{{.ClusterName}} ({{.State}} {{.Access}} Runtime {{.Runtime}})`,
|
||||
Selected: `{{ "Configured cluster" | faint }}: {{ .ClusterName | bold }} ({{.ClusterId | faint}})`,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return compatible[i].ClusterId, nil
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
package cfgpickers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/cli/libs/flags"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/qa"
|
||||
"github.com/databricks/databricks-sdk-go/service/compute"
|
||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCompatible(t *testing.T) {
|
||||
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "13.2.x-aarch64-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
||||
}, "13.0"))
|
||||
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "13.2.x-aarch64-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeNone,
|
||||
}, "13.0"))
|
||||
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "9.1.x-photon-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeNone,
|
||||
}, "13.0"))
|
||||
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "9.1.x-photon-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeNone,
|
||||
}, "10.0"))
|
||||
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "custom-9.1.x-photon-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeNone,
|
||||
}, "14.0"))
|
||||
}
|
||||
|
||||
func TestIsCompatibleWithSnapshots(t *testing.T) {
|
||||
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
|
||||
SparkVersion: "14.x-snapshot-cpu-ml-scala2.12",
|
||||
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
||||
}, "14.0"))
|
||||
}
|
||||
|
||||
func TestFirstCompatibleCluster(t *testing.T) {
|
||||
cfg, server := qa.HTTPFixtures{
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
|
||||
Response: compute.ListClustersResponse{
|
||||
Clusters: []compute.ClusterDetails{
|
||||
{
|
||||
ClusterId: "abc-id",
|
||||
ClusterName: "first shared",
|
||||
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
||||
SparkVersion: "12.2.x-whatever",
|
||||
State: compute.StateRunning,
|
||||
},
|
||||
{
|
||||
ClusterId: "bcd-id",
|
||||
ClusterName: "second personal",
|
||||
DataSecurityMode: compute.DataSecurityModeSingleUser,
|
||||
SparkVersion: "14.5.x-whatever",
|
||||
State: compute.StateRunning,
|
||||
SingleUserName: "serge",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/preview/scim/v2/Me",
|
||||
Response: iam.User{
|
||||
UserName: "serge",
|
||||
},
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/clusters/spark-versions",
|
||||
Response: compute.GetSparkVersionsResponse{
|
||||
Versions: []compute.SparkVersion{
|
||||
{
|
||||
Key: "14.5.x-whatever",
|
||||
Name: "14.5 (Awesome)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}.Config(t)
|
||||
defer server.Close()
|
||||
w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
|
||||
clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "bcd-id", clusterID)
|
||||
}
|
||||
|
||||
func TestNoCompatibleClusters(t *testing.T) {
|
||||
cfg, server := qa.HTTPFixtures{
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
|
||||
Response: compute.ListClustersResponse{
|
||||
Clusters: []compute.ClusterDetails{
|
||||
{
|
||||
ClusterId: "abc-id",
|
||||
ClusterName: "first shared",
|
||||
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
||||
SparkVersion: "12.2.x-whatever",
|
||||
State: compute.StateRunning,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/preview/scim/v2/Me",
|
||||
Response: iam.User{
|
||||
UserName: "serge",
|
||||
},
|
||||
},
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/clusters/spark-versions",
|
||||
Response: compute.GetSparkVersionsResponse{
|
||||
Versions: []compute.SparkVersion{
|
||||
{
|
||||
Key: "14.5.x-whatever",
|
||||
Name: "14.5 (Awesome)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}.Config(t)
|
||||
defer server.Close()
|
||||
w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
|
||||
_, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
|
||||
require.Equal(t, ErrNoCompatibleClusters, err)
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package cfgpickers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/service/sql"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
var ErrNoCompatibleWarehouses = errors.New("no compatible warehouses")
|
||||
|
||||
type warehouseFilter func(sql.EndpointInfo) bool
|
||||
|
||||
func WithWarehouseTypes(types ...sql.EndpointInfoWarehouseType) func(sql.EndpointInfo) bool {
|
||||
allowed := map[sql.EndpointInfoWarehouseType]bool{}
|
||||
for _, v := range types {
|
||||
allowed[v] = true
|
||||
}
|
||||
return func(ei sql.EndpointInfo) bool {
|
||||
return allowed[ei.WarehouseType]
|
||||
}
|
||||
}
|
||||
|
||||
func AskForWarehouse(ctx context.Context, w *databricks.WorkspaceClient, filters ...warehouseFilter) (string, error) {
|
||||
all, err := w.Warehouses.ListAll(ctx, sql.ListWarehousesRequest{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("list warehouses: %w", err)
|
||||
}
|
||||
var lastWarehouseID string
|
||||
names := map[string]string{}
|
||||
for _, warehouse := range all {
|
||||
var skip bool
|
||||
for _, filter := range filters {
|
||||
if !filter(warehouse) {
|
||||
skip = true
|
||||
}
|
||||
}
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
var state string
|
||||
switch warehouse.State {
|
||||
case sql.StateRunning:
|
||||
state = color.GreenString(warehouse.State.String())
|
||||
case sql.StateStopped, sql.StateDeleted, sql.StateStopping, sql.StateDeleting:
|
||||
state = color.RedString(warehouse.State.String())
|
||||
default:
|
||||
state = color.BlueString(warehouse.State.String())
|
||||
}
|
||||
visibleTouser := fmt.Sprintf("%s (%s %s)", warehouse.Name, state, warehouse.WarehouseType)
|
||||
names[visibleTouser] = warehouse.Id
|
||||
lastWarehouseID = warehouse.Id
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return "", ErrNoCompatibleWarehouses
|
||||
}
|
||||
if len(names) == 1 {
|
||||
return lastWarehouseID, nil
|
||||
}
|
||||
return cmdio.Select(ctx, names, "Choose SQL Warehouse")
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package cfgpickers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/qa"
|
||||
"github.com/databricks/databricks-sdk-go/service/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFirstCompatibleWarehouse(t *testing.T) {
|
||||
cfg, server := qa.HTTPFixtures{
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/sql/warehouses?",
|
||||
Response: sql.ListWarehousesResponse{
|
||||
Warehouses: []sql.EndpointInfo{
|
||||
{
|
||||
Id: "efg-id",
|
||||
Name: "First PRO Warehouse",
|
||||
WarehouseType: sql.EndpointInfoWarehouseTypePro,
|
||||
},
|
||||
{
|
||||
Id: "ghe-id",
|
||||
Name: "Second UNKNOWN Warehouse",
|
||||
WarehouseType: sql.EndpointInfoWarehouseTypeTypeUnspecified,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}.Config(t)
|
||||
defer server.Close()
|
||||
w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
|
||||
|
||||
ctx := context.Background()
|
||||
clusterID, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "efg-id", clusterID)
|
||||
}
|
||||
|
||||
func TestNoCompatibleWarehouses(t *testing.T) {
|
||||
cfg, server := qa.HTTPFixtures{
|
||||
{
|
||||
Method: "GET",
|
||||
Resource: "/api/2.0/sql/warehouses?",
|
||||
Response: sql.ListWarehousesResponse{
|
||||
Warehouses: []sql.EndpointInfo{
|
||||
{
|
||||
Id: "efg-id",
|
||||
Name: "...",
|
||||
WarehouseType: sql.EndpointInfoWarehouseTypeClassic,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}.Config(t)
|
||||
defer server.Close()
|
||||
w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
|
||||
assert.Equal(t, ErrNoCompatibleWarehouses, err)
|
||||
}
|
Loading…
Reference in New Issue