diff --git a/cmd/api/api.go b/cmd/api/api.go new file mode 100644 index 00000000..9e1c5fb7 --- /dev/null +++ b/cmd/api/api.go @@ -0,0 +1,85 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + + "github.com/databricks/bricks/cmd/root" + "github.com/databricks/databricks-sdk-go/databricks" + "github.com/databricks/databricks-sdk-go/databricks/client" + "github.com/spf13/cobra" +) + +var apiCmd = &cobra.Command{ + Use: "api", + Short: "Perform Databricks API call", +} + +func requestBody(arg string) (any, error) { + if arg == "" { + return nil, nil + } + + // Load request from file if it starts with '@' (like curl). + if arg[0] == '@' { + path := arg[1:] + buf, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("error reading %s: %w", path, err) + } + return buf, nil + } + + return arg, nil +} + +func makeCommand(method string) *cobra.Command { + var bodyArgument string + + command := &cobra.Command{ + Use: strings.ToLower(method), + Args: cobra.ExactArgs(1), + Short: fmt.Sprintf("Perform %s request", method), + RunE: func(cmd *cobra.Command, args []string) error { + var path = args[0] + var response any + + request, err := requestBody(bodyArgument) + if err != nil { + return err + } + + api := client.New(&databricks.Config{}) + err = api.Do(cmd.Context(), method, path, request, &response) + if err != nil { + return err + } + + if response != nil { + enc := json.NewEncoder(cmd.OutOrStdout()) + enc.SetIndent("", " ") + enc.Encode(response) + } + + return nil + }, + } + + command.Flags().StringVar(&bodyArgument, "body", "", "Request body") + return command +} + +func init() { + apiCmd.AddCommand( + makeCommand(http.MethodGet), + makeCommand(http.MethodHead), + makeCommand(http.MethodPost), + makeCommand(http.MethodPut), + makeCommand(http.MethodPatch), + makeCommand(http.MethodDelete), + ) + root.RootCmd.AddCommand(apiCmd) +} diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go new file mode 100644 index 00000000..09d72b30 --- /dev/null +++ b/cmd/api/api_test.go @@ -0,0 +1,40 @@ +package api + +import ( + "fmt" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestBodyEmpty(t *testing.T) { + out, err := requestBody("") + require.NoError(t, err) + assert.Equal(t, nil, out) +} + +func TestRequestBodyString(t *testing.T) { + out, err := requestBody("foo") + require.NoError(t, err) + assert.Equal(t, "foo", out) +} + +func TestRequestBodyFile(t *testing.T) { + var fpath string + var payload = []byte("hello world\n") + + { + f, err := os.Create(path.Join(t.TempDir(), "file")) + require.NoError(t, err) + f.Write(payload) + f.Close() + fpath = f.Name() + } + + out, err := requestBody(fmt.Sprintf("@%s", fpath)) + require.NoError(t, err) + assert.Equal(t, payload, out) +} diff --git a/cmd/root/root.go b/cmd/root/root.go index 18314b1c..186dd3aa 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -36,7 +36,7 @@ func (lw *levelWriter) Write(p []byte) (n int, err error) { a := string(p) for _, l := range *lw { if strings.Contains(a, l) { - return os.Stdout.Write(p) + return os.Stderr.Write(p) } } return diff --git a/ext/databricks-sdk-go b/ext/databricks-sdk-go index 71789bb5..b719dadd 160000 --- a/ext/databricks-sdk-go +++ b/ext/databricks-sdk-go @@ -1 +1 @@ -Subproject commit 71789bb56a381e3c14f8136c69d00a98d8536a70 +Subproject commit b719dadd27a5cb6c67db0b6ddef5458ec31cc8c0 diff --git a/internal/api_test.go b/internal/api_test.go new file mode 100644 index 00000000..ffc01eae --- /dev/null +++ b/internal/api_test.go @@ -0,0 +1,54 @@ +package internal + +import ( + "encoding/json" + "fmt" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + _ "github.com/databricks/bricks/cmd/api" +) + +func TestAccApiGet(t *testing.T) { + t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) + + stdout, _, err := run(t, "api", "get", "/api/2.0/preview/scim/v2/Me") + require.NoError(t, err) + + // Deserialize SCIM API response. + var out map[string]any + err = json.Unmarshal(stdout.Bytes(), &out) + require.NoError(t, err) + + // Assert that the output somewhat makes sense for the SCIM API. + assert.Equal(t, true, out["active"]) + assert.NotNil(t, out["id"]) +} + +func TestAccApiPost(t *testing.T) { + env := GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + if env == "gcp" { + t.Skip("DBFS REST API is disabled on gcp") + } + + dbfsPath := filepath.Join("/tmp/bricks/integration", RandomName("api-post")) + requestPath := writeFile(t, "body.json", fmt.Sprintf(`{ + "path": "%s" + }`, dbfsPath)) + + // Post to mkdir + { + _, _, err := run(t, "api", "post", "--body=@"+requestPath, "/api/2.0/dbfs/mkdirs") + require.NoError(t, err) + } + + // Post to delete + { + _, _, err := run(t, "api", "post", "--body=@"+requestPath, "/api/2.0/dbfs/delete") + require.NoError(t, err) + } +} diff --git a/internal/helpers.go b/internal/helpers.go index 7ca9f3a6..940a0965 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -1,12 +1,17 @@ package internal import ( + "bytes" "fmt" "math/rand" "os" + "path/filepath" "strings" "testing" "time" + + "github.com/databricks/bricks/cmd/root" + "github.com/stretchr/testify/require" ) const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" @@ -33,3 +38,29 @@ func RandomName(prefix ...string) string { } return string(b) } + +func run(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer, error) { + var stdout bytes.Buffer + var stderr bytes.Buffer + root := root.RootCmd + root.SetOut(&stdout) + root.SetErr(&stderr) + root.SetArgs(args) + _, err := root.ExecuteC() + if stdout.Len() > 0 { + t.Logf("[stdout]: %s", stdout.String()) + } + if stderr.Len() > 0 { + t.Logf("[stderr]: %s", stderr.String()) + } + return stdout, stderr, err +} + +func writeFile(t *testing.T, name string, body string) string { + f, err := os.Create(filepath.Join(t.TempDir(), name)) + require.NoError(t, err) + _, err = f.WriteString(body) + require.NoError(t, err) + f.Close() + return f.Name() +} diff --git a/main.go b/main.go index b5127132..805c5bb6 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + _ "github.com/databricks/bricks/cmd/api" _ "github.com/databricks/bricks/cmd/configure" _ "github.com/databricks/bricks/cmd/fs" _ "github.com/databricks/bricks/cmd/init"