mirror of https://github.com/databricks/cli.git
616 lines
16 KiB
Go
616 lines
16 KiB
Go
package internal
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/databricks/cli/cmd/root"
|
|
"github.com/databricks/cli/libs/flags"
|
|
|
|
"github.com/databricks/cli/cmd"
|
|
_ "github.com/databricks/cli/cmd/version"
|
|
"github.com/databricks/cli/libs/cmdio"
|
|
"github.com/databricks/cli/libs/filer"
|
|
"github.com/databricks/databricks-sdk-go"
|
|
"github.com/databricks/databricks-sdk-go/apierr"
|
|
"github.com/databricks/databricks-sdk-go/service/catalog"
|
|
"github.com/databricks/databricks-sdk-go/service/compute"
|
|
"github.com/databricks/databricks-sdk-go/service/files"
|
|
"github.com/databricks/databricks-sdk-go/service/jobs"
|
|
"github.com/databricks/databricks-sdk-go/service/workspace"
|
|
"github.com/spf13/cobra"
|
|
"github.com/spf13/pflag"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
_ "github.com/databricks/cli/cmd/workspace"
|
|
)
|
|
|
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
|
|
// GetEnvOrSkipTest proceeds with test only with that env variable
|
|
func GetEnvOrSkipTest(t *testing.T, name string) string {
|
|
value := os.Getenv(name)
|
|
if value == "" {
|
|
t.Skipf("Environment variable %s is missing", name)
|
|
}
|
|
return value
|
|
}
|
|
|
|
// RandomName gives random name with optional prefix. e.g. qa.RandomName("tf-")
|
|
func RandomName(prefix ...string) string {
|
|
randLen := 12
|
|
b := make([]byte, randLen)
|
|
for i := range b {
|
|
b[i] = charset[rand.Intn(randLen)]
|
|
}
|
|
if len(prefix) > 0 {
|
|
return fmt.Sprintf("%s%s", strings.Join(prefix, ""), b)
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
// Helper for running the root command in the background.
|
|
// It ensures that the background goroutine terminates upon
|
|
// test completion through cancelling the command context.
|
|
type cobraTestRunner struct {
|
|
*testing.T
|
|
|
|
args []string
|
|
stdout bytes.Buffer
|
|
stderr bytes.Buffer
|
|
stdinR *io.PipeReader
|
|
stdinW *io.PipeWriter
|
|
|
|
ctx context.Context
|
|
|
|
// Line-by-line output.
|
|
// Background goroutines populate these channels by reading from stdout/stderr pipes.
|
|
stdoutLines <-chan string
|
|
stderrLines <-chan string
|
|
|
|
errch <-chan error
|
|
}
|
|
|
|
func consumeLines(ctx context.Context, wg *sync.WaitGroup, r io.Reader) <-chan string {
|
|
ch := make(chan string, 1000)
|
|
wg.Add(1)
|
|
go func() {
|
|
defer close(ch)
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(r)
|
|
for scanner.Scan() {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case ch <- scanner.Text():
|
|
}
|
|
}
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
func (t *cobraTestRunner) registerFlagCleanup(c *cobra.Command) {
|
|
// Find target command that will be run. Example: if the command run is `databricks fs cp`,
|
|
// target command corresponds to `cp`
|
|
targetCmd, _, err := c.Find(t.args)
|
|
if err != nil && strings.HasPrefix(err.Error(), "unknown command") {
|
|
// even if command is unknown, we can proceed
|
|
require.NotNil(t, targetCmd)
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Force initialization of default flags.
|
|
// These are initialized by cobra at execution time and would otherwise
|
|
// not be cleaned up by the cleanup function below.
|
|
targetCmd.InitDefaultHelpFlag()
|
|
targetCmd.InitDefaultVersionFlag()
|
|
|
|
// Restore flag values to their original value on test completion.
|
|
targetCmd.Flags().VisitAll(func(f *pflag.Flag) {
|
|
v := reflect.ValueOf(f.Value)
|
|
if v.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
}
|
|
// Store copy of the current flag value.
|
|
reset := reflect.New(v.Type()).Elem()
|
|
reset.Set(v)
|
|
t.Cleanup(func() {
|
|
v.Set(reset)
|
|
})
|
|
})
|
|
}
|
|
|
|
// Like [cobraTestRunner.Eventually], but more specific
|
|
func (t *cobraTestRunner) WaitForTextPrinted(text string, timeout time.Duration) {
|
|
t.Eventually(func() bool {
|
|
currentStdout := t.stdout.String()
|
|
return strings.Contains(currentStdout, text)
|
|
}, timeout, 50*time.Millisecond)
|
|
}
|
|
|
|
func (t *cobraTestRunner) WaitForOutput(text string, timeout time.Duration) {
|
|
require.Eventually(t.T, func() bool {
|
|
currentStdout := t.stdout.String()
|
|
currentErrout := t.stderr.String()
|
|
return strings.Contains(currentStdout, text) || strings.Contains(currentErrout, text)
|
|
}, timeout, 50*time.Millisecond)
|
|
}
|
|
|
|
func (t *cobraTestRunner) WithStdin() {
|
|
reader, writer := io.Pipe()
|
|
t.stdinR = reader
|
|
t.stdinW = writer
|
|
}
|
|
|
|
func (t *cobraTestRunner) CloseStdin() {
|
|
if t.stdinW == nil {
|
|
panic("no standard input configured")
|
|
}
|
|
t.stdinW.Close()
|
|
}
|
|
|
|
func (t *cobraTestRunner) SendText(text string) {
|
|
if t.stdinW == nil {
|
|
panic("no standard input configured")
|
|
}
|
|
t.stdinW.Write([]byte(text + "\n"))
|
|
}
|
|
|
|
func (t *cobraTestRunner) RunBackground() {
|
|
var stdoutR, stderrR io.Reader
|
|
var stdoutW, stderrW io.WriteCloser
|
|
stdoutR, stdoutW = io.Pipe()
|
|
stderrR, stderrW = io.Pipe()
|
|
ctx := cmdio.NewContext(t.ctx, &cmdio.Logger{
|
|
Mode: flags.ModeAppend,
|
|
Reader: bufio.Reader{},
|
|
Writer: stderrW,
|
|
})
|
|
|
|
cli := cmd.New(ctx)
|
|
cli.SetOut(stdoutW)
|
|
cli.SetErr(stderrW)
|
|
cli.SetArgs(t.args)
|
|
if t.stdinW != nil {
|
|
cli.SetIn(t.stdinR)
|
|
}
|
|
|
|
// Register cleanup function to restore flags to their original values
|
|
// once test has been executed. This is needed because flag values reside
|
|
// in a global singleton data-structure, and thus subsequent tests might
|
|
// otherwise interfere with each other
|
|
t.registerFlagCleanup(cli)
|
|
|
|
errch := make(chan error)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
// Tee stdout/stderr to buffers.
|
|
stdoutR = io.TeeReader(stdoutR, &t.stdout)
|
|
stderrR = io.TeeReader(stderrR, &t.stderr)
|
|
|
|
// Consume stdout/stderr line-by-line.
|
|
var wg sync.WaitGroup
|
|
t.stdoutLines = consumeLines(ctx, &wg, stdoutR)
|
|
t.stderrLines = consumeLines(ctx, &wg, stderrR)
|
|
|
|
// Run command in background.
|
|
go func() {
|
|
err := root.Execute(ctx, cli)
|
|
if err != nil {
|
|
t.Logf("Error running command: %s", err)
|
|
}
|
|
|
|
// Close pipes to signal EOF.
|
|
stdoutW.Close()
|
|
stderrW.Close()
|
|
|
|
// Wait for the [consumeLines] routines to finish now that
|
|
// the pipes they're reading from have closed.
|
|
wg.Wait()
|
|
|
|
if t.stdout.Len() > 0 {
|
|
// Make a copy of the buffer such that it remains "unread".
|
|
scanner := bufio.NewScanner(bytes.NewBuffer(t.stdout.Bytes()))
|
|
for scanner.Scan() {
|
|
t.Logf("[databricks stdout]: %s", scanner.Text())
|
|
}
|
|
}
|
|
|
|
if t.stderr.Len() > 0 {
|
|
// Make a copy of the buffer such that it remains "unread".
|
|
scanner := bufio.NewScanner(bytes.NewBuffer(t.stderr.Bytes()))
|
|
for scanner.Scan() {
|
|
t.Logf("[databricks stderr]: %s", scanner.Text())
|
|
}
|
|
}
|
|
|
|
// Reset context on command for the next test.
|
|
// These commands are globals so we have to clean up to the best of our ability after each run.
|
|
// See https://github.com/spf13/cobra/blob/a6f198b635c4b18fff81930c40d464904e55b161/command.go#L1062-L1066
|
|
//lint:ignore SA1012 cobra sets the context and doesn't clear it
|
|
cli.SetContext(nil)
|
|
|
|
// Make caller aware of error.
|
|
errch <- err
|
|
close(errch)
|
|
}()
|
|
|
|
// Ensure command terminates upon test completion (success or failure).
|
|
t.Cleanup(func() {
|
|
// Signal termination of command.
|
|
cancel()
|
|
// Wait for goroutine to finish.
|
|
<-errch
|
|
})
|
|
|
|
t.errch = errch
|
|
}
|
|
|
|
func (t *cobraTestRunner) Run() (bytes.Buffer, bytes.Buffer, error) {
|
|
t.RunBackground()
|
|
err := <-t.errch
|
|
return t.stdout, t.stderr, err
|
|
}
|
|
|
|
// Like [require.Eventually] but errors if the underlying command has failed.
|
|
func (c *cobraTestRunner) Eventually(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) {
|
|
ch := make(chan bool, 1)
|
|
|
|
timer := time.NewTimer(waitFor)
|
|
defer timer.Stop()
|
|
|
|
ticker := time.NewTicker(tick)
|
|
defer ticker.Stop()
|
|
|
|
// Kick off condition check immediately.
|
|
go func() { ch <- condition() }()
|
|
|
|
for tick := ticker.C; ; {
|
|
select {
|
|
case err := <-c.errch:
|
|
require.Fail(c, "Command failed", err)
|
|
return
|
|
case <-timer.C:
|
|
require.Fail(c, "Condition never satisfied", msgAndArgs...)
|
|
return
|
|
case <-tick:
|
|
tick = nil
|
|
go func() { ch <- condition() }()
|
|
case v := <-ch:
|
|
if v {
|
|
return
|
|
}
|
|
tick = ticker.C
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *cobraTestRunner) RunAndExpectOutput(heredoc string) {
|
|
stdout, _, err := t.Run()
|
|
require.NoError(t, err)
|
|
require.Equal(t, cmdio.Heredoc(heredoc), strings.TrimSpace(stdout.String()))
|
|
}
|
|
|
|
func (t *cobraTestRunner) RunAndParseJSON(v any) {
|
|
stdout, _, err := t.Run()
|
|
require.NoError(t, err)
|
|
err = json.Unmarshal(stdout.Bytes(), &v)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func NewCobraTestRunner(t *testing.T, args ...string) *cobraTestRunner {
|
|
return &cobraTestRunner{
|
|
T: t,
|
|
ctx: context.Background(),
|
|
args: args,
|
|
}
|
|
}
|
|
|
|
func NewCobraTestRunnerWithContext(t *testing.T, ctx context.Context, args ...string) *cobraTestRunner {
|
|
return &cobraTestRunner{
|
|
T: t,
|
|
ctx: ctx,
|
|
args: args,
|
|
}
|
|
}
|
|
|
|
func RequireSuccessfulRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer) {
|
|
t.Logf("run args: [%s]", strings.Join(args, ", "))
|
|
c := NewCobraTestRunner(t, args...)
|
|
stdout, stderr, err := c.Run()
|
|
require.NoError(t, err)
|
|
return stdout, stderr
|
|
}
|
|
|
|
func RequireErrorRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer, error) {
|
|
c := NewCobraTestRunner(t, args...)
|
|
stdout, stderr, err := c.Run()
|
|
require.Error(t, err)
|
|
return stdout, stderr, err
|
|
}
|
|
|
|
func writeFile(t *testing.T, name string, body string) string {
|
|
f, err := os.Create(filepath.Join(t.TempDir(), name))
|
|
require.NoError(t, err)
|
|
_, err = f.WriteString(body)
|
|
require.NoError(t, err)
|
|
f.Close()
|
|
return f.Name()
|
|
}
|
|
|
|
func GenerateNotebookTasks(notebookPath string, versions []string, nodeTypeId string) []jobs.SubmitTask {
|
|
tasks := make([]jobs.SubmitTask, 0)
|
|
for i := 0; i < len(versions); i++ {
|
|
task := jobs.SubmitTask{
|
|
TaskKey: fmt.Sprintf("notebook_%s", strings.ReplaceAll(versions[i], ".", "_")),
|
|
NotebookTask: &jobs.NotebookTask{
|
|
NotebookPath: notebookPath,
|
|
},
|
|
NewCluster: &compute.ClusterSpec{
|
|
SparkVersion: versions[i],
|
|
NumWorkers: 1,
|
|
NodeTypeId: nodeTypeId,
|
|
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
|
},
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
return tasks
|
|
}
|
|
|
|
func GenerateSparkPythonTasks(notebookPath string, versions []string, nodeTypeId string) []jobs.SubmitTask {
|
|
tasks := make([]jobs.SubmitTask, 0)
|
|
for i := 0; i < len(versions); i++ {
|
|
task := jobs.SubmitTask{
|
|
TaskKey: fmt.Sprintf("spark_%s", strings.ReplaceAll(versions[i], ".", "_")),
|
|
SparkPythonTask: &jobs.SparkPythonTask{
|
|
PythonFile: notebookPath,
|
|
},
|
|
NewCluster: &compute.ClusterSpec{
|
|
SparkVersion: versions[i],
|
|
NumWorkers: 1,
|
|
NodeTypeId: nodeTypeId,
|
|
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
|
},
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
return tasks
|
|
}
|
|
|
|
func GenerateWheelTasks(wheelPath string, versions []string, nodeTypeId string) []jobs.SubmitTask {
|
|
tasks := make([]jobs.SubmitTask, 0)
|
|
for i := 0; i < len(versions); i++ {
|
|
task := jobs.SubmitTask{
|
|
TaskKey: fmt.Sprintf("whl_%s", strings.ReplaceAll(versions[i], ".", "_")),
|
|
PythonWheelTask: &jobs.PythonWheelTask{
|
|
PackageName: "my_test_code",
|
|
EntryPoint: "run",
|
|
},
|
|
NewCluster: &compute.ClusterSpec{
|
|
SparkVersion: versions[i],
|
|
NumWorkers: 1,
|
|
NodeTypeId: nodeTypeId,
|
|
DataSecurityMode: compute.DataSecurityModeUserIsolation,
|
|
},
|
|
Libraries: []compute.Library{
|
|
{Whl: wheelPath},
|
|
},
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
return tasks
|
|
}
|
|
|
|
func TemporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string {
|
|
ctx := context.Background()
|
|
me, err := w.CurrentUser.Me(ctx)
|
|
require.NoError(t, err)
|
|
|
|
basePath := fmt.Sprintf("/Users/%s/%s", me.UserName, RandomName("integration-test-wsfs-"))
|
|
|
|
t.Logf("Creating %s", basePath)
|
|
err = w.Workspace.MkdirsByPath(ctx, basePath)
|
|
require.NoError(t, err)
|
|
|
|
// Remove test directory on test completion.
|
|
t.Cleanup(func() {
|
|
t.Logf("Removing %s", basePath)
|
|
err := w.Workspace.Delete(ctx, workspace.Delete{
|
|
Path: basePath,
|
|
Recursive: true,
|
|
})
|
|
if err == nil || apierr.IsMissing(err) {
|
|
return
|
|
}
|
|
t.Logf("Unable to remove temporary workspace directory %s: %#v", basePath, err)
|
|
})
|
|
|
|
return basePath
|
|
}
|
|
|
|
func TemporaryDbfsDir(t *testing.T, w *databricks.WorkspaceClient) string {
|
|
ctx := context.Background()
|
|
path := fmt.Sprintf("/tmp/%s", RandomName("integration-test-dbfs-"))
|
|
|
|
t.Logf("Creating DBFS folder:%s", path)
|
|
err := w.Dbfs.MkdirsByPath(ctx, path)
|
|
require.NoError(t, err)
|
|
|
|
t.Cleanup(func() {
|
|
t.Logf("Removing DBFS folder:%s", path)
|
|
err := w.Dbfs.Delete(ctx, files.Delete{
|
|
Path: path,
|
|
Recursive: true,
|
|
})
|
|
if err == nil || apierr.IsMissing(err) {
|
|
return
|
|
}
|
|
t.Logf("unable to remove temporary dbfs directory %s: %#v", path, err)
|
|
})
|
|
|
|
return path
|
|
}
|
|
|
|
// Create a new UC volume in a catalog called "main" in the workspace.
|
|
func TemporaryUcVolume(t *testing.T, w *databricks.WorkspaceClient) string {
|
|
ctx := context.Background()
|
|
|
|
// Create a schema
|
|
schema, err := w.Schemas.Create(ctx, catalog.CreateSchema{
|
|
CatalogName: "main",
|
|
Name: RandomName("test-schema-"),
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
w.Schemas.Delete(ctx, catalog.DeleteSchemaRequest{
|
|
FullName: schema.FullName,
|
|
})
|
|
})
|
|
|
|
// Create a volume
|
|
volume, err := w.Volumes.Create(ctx, catalog.CreateVolumeRequestContent{
|
|
CatalogName: "main",
|
|
SchemaName: schema.Name,
|
|
Name: "my-volume",
|
|
VolumeType: catalog.VolumeTypeManaged,
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
w.Volumes.Delete(ctx, catalog.DeleteVolumeRequest{
|
|
Name: volume.FullName,
|
|
})
|
|
})
|
|
|
|
return path.Join("/Volumes", "main", schema.Name, volume.Name)
|
|
|
|
}
|
|
|
|
func TemporaryRepo(t *testing.T, w *databricks.WorkspaceClient) string {
|
|
ctx := context.Background()
|
|
me, err := w.CurrentUser.Me(ctx)
|
|
require.NoError(t, err)
|
|
|
|
repoPath := fmt.Sprintf("/Repos/%s/%s", me.UserName, RandomName("integration-test-repo-"))
|
|
|
|
t.Logf("Creating repo:%s", repoPath)
|
|
repoInfo, err := w.Repos.Create(ctx, workspace.CreateRepo{
|
|
Url: "https://github.com/databricks/cli",
|
|
Provider: "github",
|
|
Path: repoPath,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
t.Cleanup(func() {
|
|
t.Logf("Removing repo: %s", repoPath)
|
|
err := w.Repos.Delete(ctx, workspace.DeleteRepoRequest{
|
|
RepoId: repoInfo.Id,
|
|
})
|
|
if err == nil || apierr.IsMissing(err) {
|
|
return
|
|
}
|
|
t.Logf("unable to remove repo %s: %#v", repoPath, err)
|
|
})
|
|
|
|
return repoPath
|
|
}
|
|
|
|
func GetNodeTypeId(env string) string {
|
|
if env == "gcp" {
|
|
return "n1-standard-4"
|
|
} else if env == "aws" || env == "ucws" {
|
|
// aws-prod-ucws has CLOUD_ENV set to "ucws"
|
|
return "i3.xlarge"
|
|
}
|
|
return "Standard_DS4_v2"
|
|
}
|
|
|
|
func setupLocalFiler(t *testing.T) (filer.Filer, string) {
|
|
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
|
|
|
tmp := t.TempDir()
|
|
f, err := filer.NewLocalClient(tmp)
|
|
require.NoError(t, err)
|
|
|
|
return f, path.Join(filepath.ToSlash(tmp))
|
|
}
|
|
|
|
func setupWsfsFiler(t *testing.T) (filer.Filer, string) {
|
|
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
|
|
|
ctx := context.Background()
|
|
w := databricks.Must(databricks.NewWorkspaceClient())
|
|
tmpdir := TemporaryWorkspaceDir(t, w)
|
|
f, err := filer.NewWorkspaceFilesClient(w, tmpdir)
|
|
require.NoError(t, err)
|
|
|
|
// Check if we can use this API here, skip test if we cannot.
|
|
_, err = f.Read(ctx, "we_use_this_call_to_test_if_this_api_is_enabled")
|
|
var aerr *apierr.APIError
|
|
if errors.As(err, &aerr) && aerr.StatusCode == http.StatusBadRequest {
|
|
t.Skip(aerr.Message)
|
|
}
|
|
|
|
return f, tmpdir
|
|
}
|
|
|
|
func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
|
|
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
|
|
|
w := databricks.Must(databricks.NewWorkspaceClient())
|
|
tmpdir := TemporaryWorkspaceDir(t, w)
|
|
f, err := filer.NewWorkspaceFilesExtensionsClient(w, tmpdir)
|
|
require.NoError(t, err)
|
|
|
|
return f, tmpdir
|
|
}
|
|
|
|
func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
|
|
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
|
|
|
w, err := databricks.NewWorkspaceClient()
|
|
require.NoError(t, err)
|
|
|
|
tmpDir := TemporaryDbfsDir(t, w)
|
|
f, err := filer.NewDbfsClient(w, tmpDir)
|
|
require.NoError(t, err)
|
|
|
|
return f, path.Join("dbfs:/", tmpDir)
|
|
}
|
|
|
|
func setupUcVolumesFiler(t *testing.T) (filer.Filer, string) {
|
|
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
|
|
|
|
if os.Getenv("TEST_METASTORE_ID") == "" {
|
|
t.Skip("Skipping tests that require a UC Volume when metastore id is not set.")
|
|
}
|
|
|
|
w, err := databricks.NewWorkspaceClient()
|
|
require.NoError(t, err)
|
|
|
|
tmpDir := TemporaryUcVolume(t, w)
|
|
f, err := filer.NewFilesClient(w, tmpDir)
|
|
require.NoError(t, err)
|
|
|
|
return f, path.Join("dbfs:/", tmpDir)
|
|
}
|