From 3284a8c56c65cfa53cd098874f99b2b76714353e Mon Sep 17 00:00:00 2001
From: Serge Smertin <259697+nfx@users.noreply.github.com>
Date: Thu, 9 Nov 2023 17:38:45 +0100
Subject: [PATCH] Improved usability of `databricks auth login ...
 --configure-cluster` flow by displaying cluster type and runtime version
 (#956)

This PR adds selectors for Databricks-connect compatible clusters and SQL warehouses

Tested in https://github.com/databricks/cli/pull/914
---
 cmd/auth/login.go                             |  17 +-
 libs/databrickscfg/cfgpickers/clusters.go     | 228 ++++++++++++++++++
 .../databrickscfg/cfgpickers/clusters_test.go | 146 +++++++++++
 libs/databrickscfg/cfgpickers/warehouses.go   |  75 ++++++
 .../cfgpickers/warehouses_test.go             |  66 +++++
 5 files changed, 521 insertions(+), 11 deletions(-)
 create mode 100644 libs/databrickscfg/cfgpickers/clusters.go
 create mode 100644 libs/databrickscfg/cfgpickers/clusters_test.go
 create mode 100644 libs/databrickscfg/cfgpickers/warehouses.go
 create mode 100644 libs/databrickscfg/cfgpickers/warehouses_test.go

diff --git a/cmd/auth/login.go b/cmd/auth/login.go
index c2b821b6..28e0025d 100644
--- a/cmd/auth/login.go
+++ b/cmd/auth/login.go
@@ -8,9 +8,9 @@ import (
 	"github.com/databricks/cli/libs/auth"
 	"github.com/databricks/cli/libs/cmdio"
 	"github.com/databricks/cli/libs/databrickscfg"
+	"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
 	"github.com/databricks/databricks-sdk-go"
 	"github.com/databricks/databricks-sdk-go/config"
-	"github.com/databricks/databricks-sdk-go/service/compute"
 	"github.com/spf13/cobra"
 )
 
@@ -28,6 +28,8 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
 	return nil
 }
 
+const minimalDbConnectVersion = "13.1"
+
 func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
 	cmd := &cobra.Command{
 		Use:   "login [HOST]",
@@ -95,19 +97,12 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
 				return err
 			}
 			ctx := cmd.Context()
-
-			promptSpinner := cmdio.Spinner(ctx)
-			promptSpinner <- "Loading list of clusters to select from"
-			names, err := w.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{})
-			close(promptSpinner)
-			if err != nil {
-				return fmt.Errorf("failed to load clusters list. Original error: %w", err)
-			}
-			clusterId, err := cmdio.Select(ctx, names, "Choose cluster")
+			clusterID, err := cfgpickers.AskForCluster(ctx, w,
+				cfgpickers.WithDatabricksConnect(minimalDbConnectVersion))
 			if err != nil {
 				return err
 			}
-			cfg.ClusterID = clusterId
+			cfg.ClusterID = clusterID
 		}
 
 		if profileName != "" {
diff --git a/libs/databrickscfg/cfgpickers/clusters.go b/libs/databrickscfg/cfgpickers/clusters.go
new file mode 100644
index 00000000..ac037698
--- /dev/null
+++ b/libs/databrickscfg/cfgpickers/clusters.go
@@ -0,0 +1,228 @@
+// Package cfgpickers provides interactive selectors for clusters and SQL
+// warehouses.
+package cfgpickers
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"regexp"
+	"strings"
+
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/compute"
+	"github.com/databricks/databricks-sdk-go/service/iam"
+	"github.com/fatih/color"
+	"github.com/manifoldco/promptui"
+	"golang.org/x/mod/semver"
+)
+
+// minUcRuntime is the version that a requested minimal runtime has to exceed
+// for IsCompatibleWithUC to accept any cluster.
+var minUcRuntime = canonicalVersion("v12.0")
+
+// dbrVersionRegex captures "13.2" from "13.2.x-aarch64-scala2.12".
+var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`)
+
+// dbrSnapshotVersionRegex captures "14" from "14.x-snapshot-cpu-ml-scala2.12".
+var dbrSnapshotVersionRegex = regexp.MustCompile(`^(\d+)\.x-snapshot.*`)
+
+// canonicalVersion turns "13.1" or "v13.1" into the canonical semver "v13.1.0".
+func canonicalVersion(v string) string {
+	return semver.Canonical("v" + strings.TrimPrefix(v, "v"))
+}
+
+// GetRuntimeVersion extracts the "major.minor" runtime version from the Spark
+// version key of the cluster, e.g. "13.2" for "13.2.x-aarch64-scala2.12".
+// Snapshot keys, like "14.x-snapshot-cpu-ml-scala2.12", yield "14.999".
+// The second result is false when the key matches neither form.
+func GetRuntimeVersion(cluster compute.ClusterDetails) (string, bool) {
+	match := dbrVersionRegex.FindStringSubmatch(cluster.SparkVersion)
+	if len(match) < 1 {
+		match = dbrSnapshotVersionRegex.FindStringSubmatch(cluster.SparkVersion)
+		if len(match) > 1 {
+			// we return 14.999 for 14.x-snapshot for semver.Compare() to work properly
+			return fmt.Sprintf("%s.999", match[1]), true
+		}
+		return "", false
+	}
+	return match[1], true
+}
+
+// IsCompatibleWithUC reports whether the cluster runs at least minVersion of
+// the runtime in the user isolation or single user data security mode.
+// It always returns false when minVersion does not exceed minUcRuntime or
+// when the runtime version of the cluster cannot be determined.
+func IsCompatibleWithUC(cluster compute.ClusterDetails, minVersion string) bool {
+	minVersion = canonicalVersion(minVersion)
+	if semver.Compare(minUcRuntime, minVersion) >= 0 {
+		return false
+	}
+	runtimeVersion, ok := GetRuntimeVersion(cluster)
+	if !ok {
+		return false
+	}
+	clusterRuntime := canonicalVersion(runtimeVersion)
+	if semver.Compare(minVersion, clusterRuntime) > 0 {
+		return false
+	}
+	switch cluster.DataSecurityMode {
+	case compute.DataSecurityModeUserIsolation, compute.DataSecurityModeSingleUser:
+		return true
+	default:
+		return false
+	}
+}
+
+// ErrNoCompatibleClusters is returned by AskForCluster when no cluster passes
+// the filters.
+var ErrNoCompatibleClusters = errors.New("no compatible clusters found")
+
+// compatibleCluster is a cluster that passed all filters, along with the
+// display name of its runtime version. Its methods are used by the templates
+// of the cluster picker.
+type compatibleCluster struct {
+	compute.ClusterDetails
+	versionName string
+}
+
+// Access returns the label of the data security mode of the cluster.
+func (v compatibleCluster) Access() string {
+	switch v.DataSecurityMode {
+	case compute.DataSecurityModeUserIsolation:
+		return "Shared"
+	case compute.DataSecurityModeSingleUser:
+		return "Assigned"
+	default:
+		return "Unknown"
+	}
+}
+
+// Runtime returns the runtime display name, cut at the first " (".
+func (v compatibleCluster) Runtime() string {
+	runtime, _, _ := strings.Cut(v.versionName, " (")
+	return runtime
+}
+
+// State returns the name of the cluster state: green for running and resizing,
+// red for error, terminated, terminating and unknown, and blue otherwise.
+func (v compatibleCluster) State() string {
+	state := v.ClusterDetails.State
+	switch state {
+	case compute.StateRunning, compute.StateResizing:
+		return color.GreenString(state.String())
+	case compute.StateError, compute.StateTerminated, compute.StateTerminating, compute.StateUnknown:
+		return color.RedString(state.String())
+	default:
+		return color.BlueString(state.String())
+	}
+}
+
+// clusterFilter reports whether the cluster can be offered to the user me.
+type clusterFilter func(cluster *compute.ClusterDetails, me *iam.User) bool
+
+// WithDatabricksConnect returns a filter that keeps the clusters, which pass
+// IsCompatibleWithUC for minVersion, were not created by jobs, models,
+// pipelines, or SQL, and are not assigned to a user other than me.
+func WithDatabricksConnect(minVersion string) func(*compute.ClusterDetails, *iam.User) bool {
+	return func(cluster *compute.ClusterDetails, me *iam.User) bool {
+		if !IsCompatibleWithUC(*cluster, minVersion) {
+			return false
+		}
+		switch cluster.ClusterSource {
+		case compute.ClusterSourceJob,
+			compute.ClusterSourceModels,
+			compute.ClusterSourcePipeline,
+			compute.ClusterSourcePipelineMaintenance,
+			compute.ClusterSourceSql:
+			// only UI and API clusters are usable for DBConnect.
+			// `CanUseClient: "NOTEBOOKS"` didn't seem to have an effect.
+			return false
+		}
+		if cluster.SingleUserName != "" && cluster.SingleUserName != me.UserName {
+			return false
+		}
+		return true
+	}
+}
+
+// loadInteractiveClusters fetches the clusters, the current user, and the
+// runtime versions, and returns the clusters that are accepted by every
+// filter. A spinner is shown while the requests are in flight.
+func loadInteractiveClusters(ctx context.Context, w *databricks.WorkspaceClient, filters []clusterFilter) ([]compatibleCluster, error) {
+	promptSpinner := cmdio.Spinner(ctx)
+	promptSpinner <- "Loading list of clusters to select from"
+	defer close(promptSpinner)
+	all, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
+		CanUseClient: "NOTEBOOKS",
+	})
+	if err != nil {
+		return nil, fmt.Errorf("list clusters: %w", err)
+	}
+	me, err := w.CurrentUser.Me(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("current user: %w", err)
+	}
+	versions := map[string]string{}
+	sv, err := w.Clusters.SparkVersions(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("list runtime versions: %w", err)
+	}
+	for _, v := range sv.Versions {
+		versions[v.Key] = v.Name
+	}
+	var compatible []compatibleCluster
+	for _, cluster := range all {
+		var skip bool
+		for _, filter := range filters {
+			if !filter(&cluster, me) {
+				skip = true
+			}
+		}
+		if skip {
+			continue
+		}
+		compatible = append(compatible, compatibleCluster{
+			ClusterDetails: cluster,
+			versionName:    versions[cluster.SparkVersion],
+		})
+	}
+	return compatible, nil
+}
+
+// AskForCluster returns the ID of a cluster accepted by every filter. It fails
+// with ErrNoCompatibleClusters when there is none, returns the only match
+// without prompting, and otherwise lets the user pick one interactively.
+func AskForCluster(ctx context.Context, w *databricks.WorkspaceClient, filters ...clusterFilter) (string, error) {
+	compatible, err := loadInteractiveClusters(ctx, w, filters)
+	if err != nil {
+		return "", fmt.Errorf("load: %w", err)
+	}
+	if len(compatible) == 0 {
+		return "", ErrNoCompatibleClusters
+	}
+	if len(compatible) == 1 {
+		return compatible[0].ClusterId, nil
+	}
+	i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
+		Label: "Choose compatible cluster",
+		Items: compatible,
+		Searcher: func(input string, idx int) bool {
+			// lowercase both sides, so that uppercase input still matches
+			name := strings.ToLower(compatible[idx].ClusterName)
+			return strings.Contains(name, strings.ToLower(input))
+		},
+		StartInSearchMode: true,
+		Templates: &promptui.SelectTemplates{
+			Label:    "{{.ClusterName | faint}}",
+			Active:   `{{.ClusterName | bold}} ({{.State}} {{.Access}} Runtime {{.Runtime}}) ({{.ClusterId | faint}})`,
+			Inactive: `{{.ClusterName}} ({{.State}} {{.Access}} Runtime {{.Runtime}})`,
+			Selected: `{{ "Configured cluster" | faint }}: {{ .ClusterName | bold }} ({{.ClusterId | faint}})`,
+		},
+	})
+	if err != nil {
+		return "", err
+	}
+	return compatible[i].ClusterId, nil
+}
diff --git a/libs/databrickscfg/cfgpickers/clusters_test.go b/libs/databrickscfg/cfgpickers/clusters_test.go
new file mode 100644
index 00000000..362d6904
--- /dev/null
+++ b/libs/databrickscfg/cfgpickers/clusters_test.go
@@ -0,0 +1,146 @@
+package cfgpickers
+
+import (
+	"bytes"
+	"context"
+	"testing"
+
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/cli/libs/flags"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/qa"
+	"github.com/databricks/databricks-sdk-go/service/compute"
+	"github.com/databricks/databricks-sdk-go/service/iam"
+	"github.com/stretchr/testify/require"
+)
+
+func TestIsCompatible(t *testing.T) {
+	require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "13.2.x-aarch64-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeUserIsolation,
+	}, "13.0"))
+	require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "13.2.x-aarch64-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeNone,
+	}, "13.0"))
+	require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "9.1.x-photon-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeNone,
+	}, "13.0"))
+	require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "9.1.x-photon-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeNone,
+	}, "10.0"))
+	require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "custom-9.1.x-photon-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeNone,
+	}, "14.0"))
+}
+
+func TestIsCompatibleWithSnapshots(t *testing.T) {
+	require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
+		SparkVersion:     "14.x-snapshot-cpu-ml-scala2.12",
+		DataSecurityMode: compute.DataSecurityModeUserIsolation,
+	}, "14.0"))
+}
+
+func TestFirstCompatibleCluster(t *testing.T) {
+	cfg, server := qa.HTTPFixtures{
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
+			Response: compute.ListClustersResponse{
+				Clusters: []compute.ClusterDetails{
+					{
+						ClusterId:        "abc-id",
+						ClusterName:      "first shared",
+						DataSecurityMode: compute.DataSecurityModeUserIsolation,
+						SparkVersion:     "12.2.x-whatever",
+						State:            compute.StateRunning,
+					},
+					{
+						ClusterId:        "bcd-id",
+						ClusterName:      "second personal",
+						DataSecurityMode: compute.DataSecurityModeSingleUser,
+						SparkVersion:     "14.5.x-whatever",
+						State:            compute.StateRunning,
+						SingleUserName:   "serge",
+					},
+				},
+			},
+		},
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/preview/scim/v2/Me",
+			Response: iam.User{
+				UserName: "serge",
+			},
+		},
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/clusters/spark-versions",
+			Response: compute.GetSparkVersionsResponse{
+				Versions: []compute.SparkVersion{
+					{
+						Key:  "14.5.x-whatever",
+						Name: "14.5 (Awesome)",
+					},
+				},
+			},
+		},
+	}.Config(t)
+	defer server.Close()
+	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
+
+	ctx := context.Background()
+	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
+	clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
+	require.NoError(t, err)
+	require.Equal(t, "bcd-id", clusterID)
+}
+
+func TestNoCompatibleClusters(t *testing.T) {
+	cfg, server := qa.HTTPFixtures{
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
+			Response: compute.ListClustersResponse{
+				Clusters: []compute.ClusterDetails{
+					{
+						ClusterId:        "abc-id",
+						ClusterName:      "first shared",
+						DataSecurityMode: compute.DataSecurityModeUserIsolation,
+						SparkVersion:     "12.2.x-whatever",
+						State:            compute.StateRunning,
+					},
+				},
+			},
+		},
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/preview/scim/v2/Me",
+			Response: iam.User{
+				UserName: "serge",
+			},
+		},
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/clusters/spark-versions",
+			Response: compute.GetSparkVersionsResponse{
+				Versions: []compute.SparkVersion{
+					{
+						Key:  "14.5.x-whatever",
+						Name: "14.5 (Awesome)",
+					},
+				},
+			},
+		},
+	}.Config(t)
+	defer server.Close()
+	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
+
+	ctx := context.Background()
+	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
+	_, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
+	require.Equal(t, ErrNoCompatibleClusters, err)
+}
diff --git a/libs/databrickscfg/cfgpickers/warehouses.go b/libs/databrickscfg/cfgpickers/warehouses.go
new file mode 100644
index 00000000..65b5f8c8
--- /dev/null
+++ b/libs/databrickscfg/cfgpickers/warehouses.go
@@ -0,0 +1,75 @@
+package cfgpickers
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"github.com/fatih/color"
+)
+
+// ErrNoCompatibleWarehouses is returned by AskForWarehouse when no warehouse
+// passes the filters.
+var ErrNoCompatibleWarehouses = errors.New("no compatible warehouses")
+
+// warehouseFilter reports whether the warehouse can be offered to the user.
+type warehouseFilter func(sql.EndpointInfo) bool
+
+// WithWarehouseTypes returns a filter that keeps only the warehouses of the
+// given types.
+func WithWarehouseTypes(types ...sql.EndpointInfoWarehouseType) func(sql.EndpointInfo) bool {
+	allowed := map[sql.EndpointInfoWarehouseType]bool{}
+	for _, v := range types {
+		allowed[v] = true
+	}
+	return func(ei sql.EndpointInfo) bool {
+		return allowed[ei.WarehouseType]
+	}
+}
+
+// AskForWarehouse returns the ID of a warehouse accepted by every filter. It
+// fails with ErrNoCompatibleWarehouses when there is none, returns the only
+// match without prompting, and otherwise lets the user pick one interactively.
+func AskForWarehouse(ctx context.Context, w *databricks.WorkspaceClient, filters ...warehouseFilter) (string, error) {
+	all, err := w.Warehouses.ListAll(ctx, sql.ListWarehousesRequest{})
+	if err != nil {
+		return "", fmt.Errorf("list warehouses: %w", err)
+	}
+	var lastWarehouseID string
+	names := map[string]string{}
+	for _, warehouse := range all {
+		var skip bool
+		for _, filter := range filters {
+			if !filter(warehouse) {
+				skip = true
+			}
+		}
+		if skip {
+			continue
+		}
+		var state string
+		switch warehouse.State {
+		case sql.StateRunning:
+			state = color.GreenString(warehouse.State.String())
+		case sql.StateStopped, sql.StateDeleted, sql.StateStopping, sql.StateDeleting:
+			state = color.RedString(warehouse.State.String())
+		default:
+			state = color.BlueString(warehouse.State.String())
+		}
+		// NOTE(review): warehouses with the same name, state, and type share a
+		// label, so only one of them stays selectable - confirm it is acceptable.
+		visibleToUser := fmt.Sprintf("%s (%s %s)", warehouse.Name, state, warehouse.WarehouseType)
+		names[visibleToUser] = warehouse.Id
+		lastWarehouseID = warehouse.Id
+	}
+	if len(names) == 0 {
+		return "", ErrNoCompatibleWarehouses
+	}
+	if len(names) == 1 {
+		return lastWarehouseID, nil
+	}
+	return cmdio.Select(ctx, names, "Choose SQL Warehouse")
+}
diff --git a/libs/databrickscfg/cfgpickers/warehouses_test.go b/libs/databrickscfg/cfgpickers/warehouses_test.go
new file mode 100644
index 00000000..d6030b49
--- /dev/null
+++ b/libs/databrickscfg/cfgpickers/warehouses_test.go
@@ -0,0 +1,66 @@
+package cfgpickers
+
+import (
+	"context"
+	"testing"
+
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/qa"
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestFirstCompatibleWarehouse(t *testing.T) {
+	cfg, server := qa.HTTPFixtures{
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/sql/warehouses?",
+			Response: sql.ListWarehousesResponse{
+				Warehouses: []sql.EndpointInfo{
+					{
+						Id:            "efg-id",
+						Name:          "First PRO Warehouse",
+						WarehouseType: sql.EndpointInfoWarehouseTypePro,
+					},
+					{
+						Id:            "ghe-id",
+						Name:          "Second UNKNOWN Warehouse",
+						WarehouseType: sql.EndpointInfoWarehouseTypeTypeUnspecified,
+					},
+				},
+			},
+		},
+	}.Config(t)
+	defer server.Close()
+	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
+
+	ctx := context.Background()
+	clusterID, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
+	require.NoError(t, err)
+	assert.Equal(t, "efg-id", clusterID)
+}
+
+func TestNoCompatibleWarehouses(t *testing.T) {
+	cfg, server := qa.HTTPFixtures{
+		{
+			Method:   "GET",
+			Resource: "/api/2.0/sql/warehouses?",
+			Response: sql.ListWarehousesResponse{
+				Warehouses: []sql.EndpointInfo{
+					{
+						Id:            "efg-id",
+						Name:          "...",
+						WarehouseType: sql.EndpointInfoWarehouseTypeClassic,
+					},
+				},
+			},
+		},
+	}.Config(t)
+	defer server.Close()
+	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
+
+	ctx := context.Background()
+	_, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
+	assert.Equal(t, ErrNoCompatibleWarehouses, err)
+}