Improved usability of `databricks auth login ... --configure-cluster` flow by displaying cluster type and runtime version (#956)

This PR adds selectors for Databricks Connect-compatible clusters and
SQL warehouses.

Tested in https://github.com/databricks/cli/pull/914
This commit is contained in:
Serge Smertin 2023-11-09 17:38:45 +01:00 committed by GitHub
parent f111b0846e
commit 3284a8c56c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 475 additions and 11 deletions

View File

@ -8,9 +8,9 @@ import (
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/spf13/cobra"
)
@ -28,6 +28,8 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
return nil
}
const minimalDbConnectVersion = "13.1"
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
cmd := &cobra.Command{
Use: "login [HOST]",
@ -95,19 +97,12 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
return err
}
ctx := cmd.Context()
promptSpinner := cmdio.Spinner(ctx)
promptSpinner <- "Loading list of clusters to select from"
names, err := w.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{})
close(promptSpinner)
if err != nil {
return fmt.Errorf("failed to load clusters list. Original error: %w", err)
}
clusterId, err := cmdio.Select(ctx, names, "Choose cluster")
clusterID, err := cfgpickers.AskForCluster(ctx, w,
cfgpickers.WithDatabricksConnect(minimalDbConnectVersion))
if err != nil {
return err
}
cfg.ClusterID = clusterId
cfg.ClusterID = clusterID
}
if profileName != "" {

View File

@ -0,0 +1,192 @@
package cfgpickers
import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/fatih/color"
"github.com/manifoldco/promptui"
"golang.org/x/mod/semver"
)
// minUcRuntime is the lowest Databricks Runtime version that supports
// Unity Catalog, in canonical semver form.
var minUcRuntime = canonicalVersion("v12.0")

var (
	// dbrVersionRegex captures the "major.minor" prefix of a regular
	// runtime version string such as "13.2.x-scala2.12".
	dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`)
	// dbrSnapshotVersionRegex captures the major version of a snapshot
	// runtime such as "14.x-snapshot-cpu-ml-scala2.12".
	dbrSnapshotVersionRegex = regexp.MustCompile(`^(\d+)\.x-snapshot.*`)
)

// canonicalVersion normalizes v to canonical semver form
// ("vMAJOR.MINOR.PATCH"), tolerating an optional leading "v" in the input.
func canonicalVersion(v string) string {
	if !strings.HasPrefix(v, "v") {
		v = "v" + v
	}
	return semver.Canonical(v)
}
// GetRuntimeVersion extracts the runtime version (e.g. "13.2") from a
// cluster's Spark version string. Snapshot versions like "14.x-snapshot-..."
// are reported as "14.999" so that semver comparisons treat them as newer
// than any released 14.x runtime. The boolean result is false when the
// Spark version string matches neither pattern.
func GetRuntimeVersion(cluster compute.ClusterDetails) (string, bool) {
	if m := dbrVersionRegex.FindStringSubmatch(cluster.SparkVersion); len(m) > 1 {
		return m[1], true
	}
	if m := dbrSnapshotVersionRegex.FindStringSubmatch(cluster.SparkVersion); len(m) > 1 {
		// we return 14.999 for 14.x-snapshot for semver.Compare() to work properly
		return fmt.Sprintf("%s.999", m[1]), true
	}
	return "", false
}
// IsCompatibleWithUC reports whether the cluster can be used with Unity
// Catalog given a minimum required runtime version. It requires minVersion
// to be strictly above the minimal UC runtime (12.0), the cluster runtime
// to be at least minVersion, and the cluster to run in a UC-capable access
// mode (shared or single-user).
func IsCompatibleWithUC(cluster compute.ClusterDetails, minVersion string) bool {
	minVersion = canonicalVersion(minVersion)
	// A requested minimum at or below the minimal UC runtime is rejected
	// outright.
	if semver.Compare(minUcRuntime, minVersion) >= 0 {
		return false
	}
	runtimeVersion, ok := GetRuntimeVersion(cluster)
	if !ok {
		return false
	}
	// The cluster runtime must not be older than the requested minimum.
	if semver.Compare(minVersion, canonicalVersion(runtimeVersion)) > 0 {
		return false
	}
	// Only shared (user isolation) and assigned (single user) access modes
	// are UC-capable.
	mode := cluster.DataSecurityMode
	return mode == compute.DataSecurityModeUserIsolation ||
		mode == compute.DataSecurityModeSingleUser
}
// ErrNoCompatibleClusters is returned by AskForCluster when no cluster
// passes all of the supplied filters.
var ErrNoCompatibleClusters = errors.New("no compatible clusters found")

// compatibleCluster decorates a cluster with the human-readable name of its
// runtime version for display in the interactive picker.
type compatibleCluster struct {
	compute.ClusterDetails
	// versionName is the display name for the cluster's SparkVersion key,
	// as reported by the spark-versions API.
	versionName string
}
// Access renders the cluster's data security mode as a short label for the
// picker: "Shared" (user isolation), "Assigned" (single user), or "Unknown".
func (v compatibleCluster) Access() string {
	if v.DataSecurityMode == compute.DataSecurityModeUserIsolation {
		return "Shared"
	}
	if v.DataSecurityMode == compute.DataSecurityModeSingleUser {
		return "Assigned"
	}
	return "Unknown"
}
// Runtime returns the short runtime name: the version display name with any
// trailing " (...)" suffix stripped, e.g. "14.5 (Awesome)" becomes "14.5".
func (v compatibleCluster) Runtime() string {
	if i := strings.Index(v.versionName, " ("); i >= 0 {
		return v.versionName[:i]
	}
	return v.versionName
}
// State renders the cluster state color-coded for the picker: green for
// usable states (running/resizing), red for dead or dying states, and blue
// for everything else (e.g. pending).
func (v compatibleCluster) State() string {
	s := v.ClusterDetails.State
	text := s.String()
	switch s {
	case compute.StateRunning, compute.StateResizing:
		return color.GreenString(text)
	case compute.StateError, compute.StateTerminated, compute.StateTerminating, compute.StateUnknown:
		return color.RedString(text)
	}
	return color.BlueString(text)
}
// clusterFilter decides whether a cluster should be offered to the current
// user in the interactive picker.
type clusterFilter func(cluster *compute.ClusterDetails, me *iam.User) bool

// WithDatabricksConnect returns a filter that keeps only clusters usable
// with Databricks Connect: UC-compatible at the given minimum runtime
// version, created via the UI or API, and not assigned to another user.
func WithDatabricksConnect(minVersion string) func(*compute.ClusterDetails, *iam.User) bool {
	return func(cluster *compute.ClusterDetails, me *iam.User) bool {
		if !IsCompatibleWithUC(*cluster, minVersion) {
			return false
		}
		switch cluster.ClusterSource {
		case compute.ClusterSourceJob,
			compute.ClusterSourceModels,
			compute.ClusterSourcePipeline,
			compute.ClusterSourcePipelineMaintenance,
			compute.ClusterSourceSql:
			// only UI and API clusters are usable for DBConnect.
			// `CanUseClient: "NOTEBOOKS"` didn't seem to have an effect.
			return false
		}
		// A single-user cluster is usable only by the user it is assigned to.
		return cluster.SingleUserName == "" || cluster.SingleUserName == me.UserName
	}
}
func loadInteractiveClusters(ctx context.Context, w *databricks.WorkspaceClient, filters []clusterFilter) ([]compatibleCluster, error) {
promptSpinner := cmdio.Spinner(ctx)
promptSpinner <- "Loading list of clusters to select from"
defer close(promptSpinner)
all, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
CanUseClient: "NOTEBOOKS",
})
if err != nil {
return nil, fmt.Errorf("list clusters: %w", err)
}
me, err := w.CurrentUser.Me(ctx)
if err != nil {
return nil, fmt.Errorf("current user: %w", err)
}
versions := map[string]string{}
sv, err := w.Clusters.SparkVersions(ctx)
if err != nil {
return nil, fmt.Errorf("list runtime versions: %w", err)
}
for _, v := range sv.Versions {
versions[v.Key] = v.Name
}
var compatible []compatibleCluster
for _, cluster := range all {
var skip bool
for _, filter := range filters {
if !filter(&cluster, me) {
skip = true
}
}
if skip {
continue
}
compatible = append(compatible, compatibleCluster{
ClusterDetails: cluster,
versionName: versions[cluster.SparkVersion],
})
}
return compatible, nil
}
// AskForCluster returns the ID of a cluster chosen interactively by the
// user among those passing all supplied filters. When exactly one cluster
// is compatible it is returned without prompting; when none are,
// ErrNoCompatibleClusters is returned.
func AskForCluster(ctx context.Context, w *databricks.WorkspaceClient, filters ...clusterFilter) (string, error) {
	compatible, err := loadInteractiveClusters(ctx, w, filters)
	if err != nil {
		return "", fmt.Errorf("load: %w", err)
	}
	if len(compatible) == 0 {
		return "", ErrNoCompatibleClusters
	}
	// Skip the prompt entirely when there is only one possible choice.
	if len(compatible) == 1 {
		return compatible[0].ClusterId, nil
	}
	i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
		Label: "Choose compatible cluster",
		Items: compatible,
		Searcher: func(input string, idx int) bool {
			// Case-insensitive substring match: the candidate name is
			// lowercased, so the input must be lowercased as well.
			lower := strings.ToLower(compatible[idx].ClusterName)
			return strings.Contains(lower, strings.ToLower(input))
		},
		StartInSearchMode: true,
		Templates: &promptui.SelectTemplates{
			Label:    "{{.ClusterName | faint}}",
			Active:   `{{.ClusterName | bold}} ({{.State}} {{.Access}} Runtime {{.Runtime}}) ({{.ClusterId | faint}})`,
			Inactive: `{{.ClusterName}} ({{.State}} {{.Access}} Runtime {{.Runtime}})`,
			Selected: `{{ "Configured cluster" | faint }}: {{ .ClusterName | bold }} ({{.ClusterId | faint}})`,
		},
	})
	if err != nil {
		return "", err
	}
	return compatible[i].ClusterId, nil
}

View File

@ -0,0 +1,146 @@
package cfgpickers
import (
"bytes"
"context"
"testing"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/qa"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/stretchr/testify/require"
)
func TestIsCompatible(t *testing.T) {
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "13.2.x-aarch64-scala2.12",
DataSecurityMode: compute.DataSecurityModeUserIsolation,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "13.2.x-aarch64-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "10.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "custom-9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "14.0"))
}
func TestIsCompatibleWithSnapshots(t *testing.T) {
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "14.x-snapshot-cpu-ml-scala2.12",
DataSecurityMode: compute.DataSecurityModeUserIsolation,
}, "14.0"))
}
// TestFirstCompatibleCluster verifies that AskForCluster returns the single
// compatible cluster without prompting: the shared 12.2 cluster fails the
// DBConnect minimum (13.1), leaving only the 14.5 single-user cluster
// assigned to the current user.
func TestFirstCompatibleCluster(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						// Runtime 12.2 is below the 13.1 minimum: filtered out.
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
					{
						// Assigned to the current user and new enough: kept.
						ClusterId:        "bcd-id",
						ClusterName:      "second personal",
						DataSecurityMode: compute.DataSecurityModeSingleUser,
						SparkVersion:     "14.5.x-whatever",
						State:            compute.StateRunning,
						SingleUserName:   "serge",
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
	ctx := context.Background()
	// Non-interactive IO: the single compatible cluster is auto-selected.
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.NoError(t, err)
	require.Equal(t, "bcd-id", clusterID)
}
// TestNoCompatibleClusters verifies that AskForCluster returns
// ErrNoCompatibleClusters when every listed cluster fails the DBConnect
// filter (the only cluster runs 12.2, below the 13.1 minimum).
func TestNoCompatibleClusters(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						// Runtime 12.2 is below the 13.1 minimum: filtered out.
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
	ctx := context.Background()
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	_, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.Equal(t, ErrNoCompatibleClusters, err)
}

View File

@ -0,0 +1,65 @@
package cfgpickers
import (
"context"
"errors"
"fmt"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/fatih/color"
)
// ErrNoCompatibleWarehouses is returned by AskForWarehouse when no SQL
// warehouse passes all of the supplied filters.
var ErrNoCompatibleWarehouses = errors.New("no compatible warehouses")

// warehouseFilter decides whether a SQL warehouse should be offered to the
// user in the interactive picker.
type warehouseFilter func(sql.EndpointInfo) bool

// WithWarehouseTypes returns a filter that keeps only warehouses of one of
// the given types.
func WithWarehouseTypes(types ...sql.EndpointInfoWarehouseType) func(sql.EndpointInfo) bool {
	allowed := make(map[sql.EndpointInfoWarehouseType]struct{}, len(types))
	for _, t := range types {
		allowed[t] = struct{}{}
	}
	return func(ei sql.EndpointInfo) bool {
		_, ok := allowed[ei.WarehouseType]
		return ok
	}
}
func AskForWarehouse(ctx context.Context, w *databricks.WorkspaceClient, filters ...warehouseFilter) (string, error) {
all, err := w.Warehouses.ListAll(ctx, sql.ListWarehousesRequest{})
if err != nil {
return "", fmt.Errorf("list warehouses: %w", err)
}
var lastWarehouseID string
names := map[string]string{}
for _, warehouse := range all {
var skip bool
for _, filter := range filters {
if !filter(warehouse) {
skip = true
}
}
if skip {
continue
}
var state string
switch warehouse.State {
case sql.StateRunning:
state = color.GreenString(warehouse.State.String())
case sql.StateStopped, sql.StateDeleted, sql.StateStopping, sql.StateDeleting:
state = color.RedString(warehouse.State.String())
default:
state = color.BlueString(warehouse.State.String())
}
visibleTouser := fmt.Sprintf("%s (%s %s)", warehouse.Name, state, warehouse.WarehouseType)
names[visibleTouser] = warehouse.Id
lastWarehouseID = warehouse.Id
}
if len(names) == 0 {
return "", ErrNoCompatibleWarehouses
}
if len(names) == 1 {
return lastWarehouseID, nil
}
return cmdio.Select(ctx, names, "Choose SQL Warehouse")
}

View File

@ -0,0 +1,66 @@
package cfgpickers
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/qa"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFirstCompatibleWarehouse verifies that AskForWarehouse returns the
// single warehouse matching the PRO type filter without prompting.
func TestFirstCompatibleWarehouse(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/sql/warehouses?",
			Response: sql.ListWarehousesResponse{
				Warehouses: []sql.EndpointInfo{
					{
						// Matches the PRO filter: expected result.
						Id:            "efg-id",
						Name:          "First PRO Warehouse",
						WarehouseType: sql.EndpointInfoWarehouseTypePro,
					},
					{
						// Unspecified type: filtered out.
						Id:            "ghe-id",
						Name:          "Second UNKNOWN Warehouse",
						WarehouseType: sql.EndpointInfoWarehouseTypeTypeUnspecified,
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
	ctx := context.Background()
	warehouseID, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
	require.NoError(t, err)
	assert.Equal(t, "efg-id", warehouseID)
}
// TestNoCompatibleWarehouses verifies that AskForWarehouse returns
// ErrNoCompatibleWarehouses when the only listed warehouse (CLASSIC) does
// not match the PRO type filter.
func TestNoCompatibleWarehouses(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/sql/warehouses?",
			Response: sql.ListWarehousesResponse{
				Warehouses: []sql.EndpointInfo{
					{
						// CLASSIC type does not match the PRO filter below.
						Id:            "efg-id",
						Name:          "...",
						WarehouseType: sql.EndpointInfoWarehouseTypeClassic,
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))
	ctx := context.Background()
	_, err := AskForWarehouse(ctx, w, WithWarehouseTypes(sql.EndpointInfoWarehouseTypePro))
	assert.Equal(t, ErrNoCompatibleWarehouses, err)
}