Compare commits


39 Commits

Author SHA1 Message Date
shreyas-goenka 029c292cf3
Merge e9b0afb337 into 24ac8d8d59 2025-02-11 21:30:35 +05:30
shreyas-goenka 24ac8d8d59
Add acceptance tests for auth resolution (#2285)
## Changes

This PR adds acceptance tests for native Databricks auth methods: basic,
oauth, and pat.

In the future we could compare the recorded credentials with those used by
downstream tools like TF or the telemetry process to ensure consistent
auth credentials are picked up and used.

Note: 
We do not add acceptance tests for other auth methods like Azure because
they communicate with external endpoints. To test them locally, we would
need to set up a reverse proxy server, which is out of scope for this
change.
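
For reference, a minimal standalone sketch of how the basic-auth credentials used in these tests become the Authorization header value asserted in the recorded requests:

```go
package main

import (
	"encoding/base64"
	"fmt"
)

func main() {
	// "username:password" encodes to dXNlcm5hbWU6cGFzc3dvcmQ=, the value
	// the acceptance test expects in the recorded Authorization header.
	creds := base64.StdEncoding.EncodeToString([]byte("username:password"))
	fmt.Println("Authorization: Basic " + creds)
}
```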

## Tests
N/A
2025-02-11 15:50:03 +00:00
Denis Bilenko 5d392acbef
acc: Allow mixing custom stubs with default server impl (#2334)
## Changes
- Currently, if you define a [[Server]] block, you disable the default
server implementation. With this change, the [[Server]] block takes
precedence over the default server, but the default server remains.
- Switched the mux implementation to
[gorilla/mux](https://github.com/gorilla/mux) -- unlike the built-in mux,
it does not panic if you register two handlers for the same path (instead
the earliest one wins; see the sketch after this list). It also has no
dependencies.
- Moved acceptance/selftest into acceptance/selftest/basic and added
acceptance/selftest/server, which demos the server override.
- Rewrote server setup to ensure that env vars and replacements are set
up correctly. Previously, replacements for DATABRICKS_HOST referred to
the default server, not to the custom server.
- Avoid calling CurrentUser.Me() in the local case. This allows
overriding /api/2.0/preview/scim/v2/Me, which we use in some tests (e.g.
bundle/templates-machinery/helpers-error). Previously the test only
passed because CurrentUser.Me() was hitting the default server, which
was incorrect.
- The default server is now available via the DATABRICKS_DEFAULT_HOST
env var.
- Rewrite "not found" handler in local test to handle error better (do
not raise http500 when header is already written).
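
For illustration, a minimal standalone sketch (not part of this diff) of the registration-order behavior described in the gorilla/mux bullet above:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"

	"github.com/gorilla/mux"
)

func main() {
	r := mux.NewRouter()

	// First registration wins: gorilla/mux matches routes in
	// registration order, so this "custom stub" takes precedence.
	r.HandleFunc("/api/2.0/preview/scim/v2/Me", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "custom stub")
	}).Methods("GET")

	// Registering the same path again does not panic (net/http's
	// ServeMux would); this default handler is simply never reached.
	r.HandleFunc("/api/2.0/preview/scim/v2/Me", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "default handler")
	}).Methods("GET")

	srv := httptest.NewServer(r)
	defer srv.Close()

	resp, err := http.Get(srv.URL + "/api/2.0/preview/scim/v2/Me")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // custom stub
}
```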

## Tests
The new acceptance test selftest/server verifies that both custom and
default handlers are available within a single test.
2025-02-11 15:03:41 +00:00
Denis Bilenko 272ce61302
acc: Fix singleTest option to support forward slashes (#2336)
The filtering of tests needs to see forward slashes; otherwise it is
OS-dependent.

I've also switched to filepath.ToSlash, but it should be a no-op.
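
A small standalone sketch of the normalization (paths are assumed for illustration; on Unix, ToSlash is indeed a no-op):

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

func main() {
	// On Windows, filepath.Dir yields `selftest\basic`; the single-test
	// filter uses forward slashes, so normalize before comparing.
	dir := filepath.Dir(filepath.Join("selftest", "basic", "script"))
	name := filepath.ToSlash(dir)
	fmt.Println(strings.HasPrefix(name, "selftest/basic")) // true on every OS
}
```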
2025-02-11 15:26:46 +01:00
Denis Bilenko 878fa80322
acc: Fix RecordRequests to support requests without body (#2333)
## Changes
Do not paste the request body into the output if it's not valid JSON.

## Tests
While working on #2334, I found that trying to record a test that calls
/api/2.0/preview/scim/v2/Me, which has no request body, crashed.
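
A minimal standalone sketch of the guard this fix amounts to (mirroring the testserver change further down in this diff):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the testserver.Request shape changed later in this diff.
type Request struct {
	Body    any    `json:"body,omitempty"`
	RawBody string `json:"raw_body,omitempty"`
}

// Embed the body as JSON only when it actually parses; otherwise keep
// it as a raw string so marshalling the recorded request cannot fail.
func record(body []byte) Request {
	if json.Valid(body) {
		return Request{Body: json.RawMessage(body)}
	}
	return Request{RawBody: string(body)}
}

func main() {
	fmt.Printf("%+v\n", record([]byte(`{"a":1}`)))                        // JSON body
	fmt.Printf("%+v\n", record([]byte("grant_type=client_credentials"))) // raw form body
	fmt.Printf("%+v\n", record(nil))                                     // empty body: nothing to record
}
```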
2025-02-11 10:50:52 +00:00
Denis Bilenko 8d849fe868
acc: Disable custom server on CLOUD_ENV (#2332)
We're not using the local server when CLOUD_ENV is set, so there is no
need to set up a custom one.
2025-02-11 10:37:48 +00:00
Shreyas Goenka e9b0afb337
- 2025-01-03 00:03:16 +05:30
Shreyas Goenka ee9499bc68
use write testutil 2025-01-03 00:01:55 +05:30
Shreyas Goenka f4623ebbb9
cleanup 2025-01-02 23:56:58 +05:30
Shreyas Goenka ac37ca0d98
- 2025-01-02 18:54:35 +05:30
Shreyas Goenka 7ab9fb7cec
simplify copy 2025-01-02 18:53:14 +05:30
Shreyas Goenka 9552131a2a
lint 2025-01-02 18:11:35 +05:30
Shreyas Goenka 583637aed6
lint 2025-01-02 18:08:14 +05:30
Shreyas Goenka 6991dea00b
make content length work 2025-01-02 17:59:53 +05:30
Shreyas Goenka 1e2545eedf
merge 2025-01-02 17:21:17 +05:30
Shreyas Goenka f70c47253e
calculate content length before upload 2025-01-02 17:20:22 +05:30
Shreyas Goenka 890b48f70d
Reapply "add streaming uploads"
This reverts commit 8ec1e0746d.
2025-01-02 16:55:56 +05:30
Shreyas Goenka 8ec1e0746d
Revert "add streaming uploads"
This reverts commit 95f41b1f30.
2025-01-02 12:11:38 +05:30
Shreyas Goenka 95f41b1f30
add streaming uploads 2025-01-02 11:59:26 +05:30
Shreyas Goenka 92e97ad413
- 2024-12-31 17:11:39 +05:30
Shreyas Goenka 69fdd9736b
reduce diff 2024-12-31 17:10:19 +05:30
Shreyas Goenka ee8017357e
fix size 2024-12-31 16:07:53 +05:30
Shreyas Goenka 7084392a0f
cleanup code 2024-12-31 13:40:10 +05:30
Shreyas Goenka be62ead7be
lint 2024-12-31 13:31:38 +05:30
Shreyas Goenka cf51636faa
overwrite fix 2024-12-31 13:28:36 +05:30
Shreyas Goenka 09bf4fa90c
add unit test 2024-12-31 12:53:25 +05:30
Shreyas Goenka 9d8ba099ba
fix fd lingering' 2024-12-31 12:47:47 +05:30
Shreyas Goenka 932aeee349
ignore linter 2024-12-31 12:46:48 +05:30
Shreyas Goenka 63e599ccb2
added integration test 2024-12-31 12:31:37 +05:30
Shreyas Goenka 0ce50fadf0
add unit test 2024-12-31 12:09:24 +05:30
Shreyas Goenka 2717aca239
Merge remote-tracking branch 'origin' into multipart-dbfs 2024-12-31 10:27:45 +05:30
Shreyas Goenka 78b9788bc6
some cleanup 2024-12-30 22:43:45 +05:30
Shreyas Goenka da46c142e0
Merge remote-tracking branch 'origin' into multipart-dbfs 2024-12-30 21:20:17 +05:30
Shreyas Goenka f690f0a342
Merge remote-tracking branch 'origin' into multipart-dbfs 2024-12-05 15:59:07 +05:30
Shreyas Goenka 04cae2f8cf
Merge remote-tracking branch 'origin' into multipart-dbfs 2024-12-04 20:19:59 +01:00
Shreyas Goenka 4b484fdcdc
todo 2024-12-03 00:07:14 +01:00
Shreyas Goenka 06af01c8f6
refactor to make tests easier 2024-12-03 00:05:04 +01:00
Shreyas Goenka 91a2dfa0ed
- 2024-12-02 23:49:32 +01:00
Shreyas Goenka c4cea1aeff
Use the `/dbfs/put` API endpoint to upload smaller DBFS files 2024-12-02 23:47:31 +01:00
32 changed files with 711 additions and 109 deletions

NOTICE

@ -114,3 +114,7 @@ dario.cat/mergo
Copyright (c) 2013 Dario Castañé. All rights reserved.
Copyright (c) 2012 The Go Authors. All rights reserved.
https://github.com/darccio/mergo/blob/master/LICENSE
https://github.com/gorilla/mux
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
https://github.com/gorilla/mux/blob/main/LICENSE


@ -11,6 +11,7 @@ import (
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"slices"
"sort"
@ -26,6 +27,7 @@ import (
"github.com/databricks/cli/libs/testdiff"
"github.com/databricks/cli/libs/testserver"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/stretchr/testify/require"
)
@ -72,7 +74,8 @@ func TestInprocessMode(t *testing.T) {
if InprocessMode {
t.Skip("Already tested by TestAccept")
}
require.Equal(t, 1, testAccept(t, true, "selftest"))
require.Equal(t, 1, testAccept(t, true, "selftest/basic"))
require.Equal(t, 1, testAccept(t, true, "selftest/server"))
}
func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
@ -118,14 +121,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
uvCache := getUVDefaultCacheDir(t)
t.Setenv("UV_CACHE_DIR", uvCache)
ctx := context.Background()
cloudEnv := os.Getenv("CLOUD_ENV")
if cloudEnv == "" {
defaultServer := testserver.New(t)
AddHandlers(defaultServer)
// Redirect API access to local server:
t.Setenv("DATABRICKS_HOST", defaultServer.URL)
t.Setenv("DATABRICKS_DEFAULT_HOST", defaultServer.URL)
homeDir := t.TempDir()
// Do not read user's ~/.databrickscfg
@ -148,27 +149,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
// do it last so that full paths match first:
repls.SetPath(buildDir, "[BUILD_DIR]")
var config databricks.Config
if cloudEnv == "" {
// use fake token for local tests
config = databricks.Config{Token: "dbapi1234"}
} else {
// non-local tests rely on environment variables
config = databricks.Config{}
}
workspaceClient, err := databricks.NewWorkspaceClient(&config)
require.NoError(t, err)
user, err := workspaceClient.CurrentUser.Me(ctx)
require.NoError(t, err)
require.NotNil(t, user)
testdiff.PrepareReplacementsUser(t, &repls, *user)
testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient)
testdiff.PrepareReplacementsUUID(t, &repls)
testdiff.PrepareReplacementsDevVersion(t, &repls)
testdiff.PrepareReplacementSdkVersion(t, &repls)
testdiff.PrepareReplacementsGoVersion(t, &repls)
repls.Repls = append(repls.Repls, testdiff.Replacement{Old: regexp.MustCompile("dbapi[0-9a-f]+"), New: "[DATABRICKS_TOKEN]"})
testDirs := getTests(t)
require.NotEmpty(t, testDirs)
@ -180,8 +166,7 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
}
for _, dir := range testDirs {
testName := strings.ReplaceAll(dir, "\\", "/")
t.Run(testName, func(t *testing.T) {
t.Run(dir, func(t *testing.T) {
if !InprocessMode {
t.Parallel()
}
@ -203,7 +188,8 @@ func getTests(t *testing.T) []string {
name := filepath.Base(path)
if name == EntryPointScript {
// Presence of 'script' marks a test case in this directory
testDirs = append(testDirs, filepath.Dir(path))
testName := filepath.ToSlash(filepath.Dir(path))
testDirs = append(testDirs, testName)
}
return nil
})
@ -239,7 +225,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
}
repls.SetPathWithParents(tmpDir, "[TMPDIR]")
repls.Repls = append(repls.Repls, config.Repls...)
scriptContents := readMergedScriptContents(t, dir)
testutil.WriteFile(t, filepath.Join(tmpDir, EntryPointScript), scriptContents)
@ -253,38 +238,79 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
cmd := exec.Command(args[0], args[1:]...)
cmd.Env = os.Environ()
var workspaceClient *databricks.WorkspaceClient
var user iam.User
// Start a new server with a custom configuration if the acceptance test
// specifies custom server stubs.
var server *testserver.Server
// Start a new server for this test if either:
// 1. A custom server spec is defined in the test configuration.
// 2. The test is configured to record requests and assert on them. We need
// a duplicate of the default server to record requests because the default
// server otherwise is a shared resource.
if len(config.Server) > 0 || config.RecordRequests {
server = testserver.New(t)
server.RecordRequests = config.RecordRequests
server.IncludeRequestHeaders = config.IncludeRequestHeaders
if cloudEnv == "" {
// Start a new server for this test if either:
// 1. A custom server spec is defined in the test configuration.
// 2. The test is configured to record requests and assert on them. We need
// a duplicate of the default server to record requests because the default
// server otherwise is a shared resource.
// If no custom server stubs are defined, add the default handlers.
if len(config.Server) == 0 {
databricksLocalHost := os.Getenv("DATABRICKS_DEFAULT_HOST")
if len(config.Server) > 0 || config.RecordRequests {
server = testserver.New(t)
server.RecordRequests = config.RecordRequests
server.IncludeRequestHeaders = config.IncludeRequestHeaders
for _, stub := range config.Server {
require.NotEmpty(t, stub.Pattern)
items := strings.Split(stub.Pattern, " ")
require.Len(t, items, 2)
server.Handle(items[0], items[1], func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) {
statusCode := http.StatusOK
if stub.Response.StatusCode != 0 {
statusCode = stub.Response.StatusCode
}
return stub.Response.Body, statusCode
})
}
// The earliest handlers take precedence; add default handlers last
AddHandlers(server)
databricksLocalHost = server.URL
}
for _, stub := range config.Server {
require.NotEmpty(t, stub.Pattern)
server.Handle(stub.Pattern, func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) {
statusCode := http.StatusOK
if stub.Response.StatusCode != 0 {
statusCode = stub.Response.StatusCode
}
return stub.Response.Body, statusCode
})
// Each local test should use a new token, which results in a new fake workspace,
// so that tests don't interfere with each other.
tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "")
config := databricks.Config{
Host: databricksLocalHost,
Token: "dbapi" + tokenSuffix,
}
cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+server.URL)
workspaceClient, err = databricks.NewWorkspaceClient(&config)
require.NoError(t, err)
cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+config.Host)
cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+config.Token)
// For the purposes of replacements, use testUser.
// Note: users might have overridden /api/2.0/preview/scim/v2/Me, but that should not affect the replacement:
user = testUser
} else {
// Use whatever authentication mechanism is configured by the test runner.
workspaceClient, err = databricks.NewWorkspaceClient(&databricks.Config{})
require.NoError(t, err)
pUser, err := workspaceClient.CurrentUser.Me(context.Background())
require.NoError(t, err, "Failed to get current user")
user = *pUser
}
testdiff.PrepareReplacementsUser(t, &repls, user)
testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient)
// Must be added after PrepareReplacementsUser, otherwise it conflicts with [USERNAME]
testdiff.PrepareReplacementsUUID(t, &repls)
// User replacements come last:
repls.Repls = append(repls.Repls, config.Repls...)
if coverDir != "" {
// Creating individual coverage directory for each test, because writing to the same one
// results in sporadic failures like this one (only if tests are running in parallel):
@ -295,15 +321,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
cmd.Env = append(cmd.Env, "GOCOVERDIR="+coverDir)
}
// Each local test should use a new token, which results in a new fake workspace,
// so that tests don't interfere with each other.
if cloudEnv == "" {
tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "")
token := "dbapi" + tokenSuffix
cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+token)
repls.Set(token, "[DATABRICKS_TOKEN]")
}
// Write combined output to a file
out, err := os.Create(filepath.Join(tmpDir, "output.txt"))
require.NoError(t, err)
@ -320,7 +337,7 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
for _, req := range server.Requests {
reqJson, err := json.MarshalIndent(req, "", " ")
require.NoError(t, err)
require.NoErrorf(t, err, "Failed to indent: %#v", req)
reqJsonWithRepls := repls.Replace(string(reqJson))
_, err = f.WriteString(reqJsonWithRepls + "\n")
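
For context, a standalone sketch of the per-test isolation scheme in the diff above, using the same token construction (the regex replacement is the one registered in testAccept):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/google/uuid"
)

func main() {
	// Each local test gets a fresh fake token; per the comment in the
	// diff, a new token results in a new fake workspace on the local
	// server, so tests don't share state.
	tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "")
	token := "dbapi" + tokenSuffix
	fmt.Println("DATABRICKS_TOKEN=" + token)
	// The replacement dbapi[0-9a-f]+ -> [DATABRICKS_TOKEN] added above
	// keeps recorded output deterministic despite the random token.
}
```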


@ -13,13 +13,13 @@
=== Inside the bundle, profile flag not matching bundle host. Badness: should use profile from flag instead and not fail
>>> errcode [CLI] current-user me -p profile_name
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL]
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET]
Exit code: 1
=== Inside the bundle, target and not matching profile
>>> errcode [CLI] current-user me -t dev -p profile_name
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL]
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET]
Exit code: 1


@ -5,4 +5,8 @@ Badness = "When -p flag is used inside the bundle folder for any CLI commands, C
# This is a workaround to replace both DATABRICKS_HOST and DATABRICKS_URL with the same DATABRICKS_TARGET placeholder
[[Repls]]
Old='DATABRICKS_HOST'
New='DATABRICKS_URL'
New='DATABRICKS_TARGET'
[[Repls]]
Old='DATABRICKS_URL'
New='DATABRICKS_TARGET'


@ -0,0 +1,12 @@
{
"headers": {
"Authorization": [
"Basic [ENCODED_AUTH]"
],
"User-Agent": [
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/basic"
]
},
"method": "GET",
"path": "/api/2.0/preview/scim/v2/Me"
}


@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}


@ -0,0 +1,8 @@
# Unset the token which is configured by default
# in acceptance tests
export DATABRICKS_TOKEN=""
export DATABRICKS_USERNAME=username
export DATABRICKS_PASSWORD=password
$CLI current-user me


@ -0,0 +1,4 @@
# "username:password" in base64 is dXNlcm5hbWU6cGFzc3dvcmQ=, expect to see this in Authorization header
[[Repls]]
Old = "dXNlcm5hbWU6cGFzc3dvcmQ="
New = "[ENCODED_AUTH]"


@ -0,0 +1,34 @@
{
"headers": {
"User-Agent": [
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS]"
]
},
"method": "GET",
"path": "/oidc/.well-known/oauth-authorization-server"
}
{
"headers": {
"Authorization": [
"Basic [ENCODED_AUTH]"
],
"User-Agent": [
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS]"
]
},
"method": "POST",
"path": "/oidc/v1/token",
"raw_body": "grant_type=client_credentials\u0026scope=all-apis"
}
{
"headers": {
"Authorization": [
"Bearer oauth-token"
],
"User-Agent": [
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/oauth-m2m"
]
},
"method": "GET",
"path": "/api/2.0/preview/scim/v2/Me"
}
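
The three requests recorded above trace the OAuth client-credentials (M2M) flow: endpoint discovery, a token exchange that sends the client ID and secret as basic auth, then the API call with the resulting bearer token. A standalone sketch of the token request (the host is hypothetical; the tests point it at the local server):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
)

func main() {
	// Token exchange as recorded above: client_id/client_secret go in a
	// basic-auth header, the grant type and scope in the form body.
	form := url.Values{}
	form.Set("grant_type", "client_credentials")
	form.Set("scope", "all-apis")

	req, err := http.NewRequest("POST", "https://example.databricks.com/oidc/v1/token",
		strings.NewReader(form.Encode()))
	if err != nil {
		panic(err)
	}
	req.SetBasicAuth("client_id", "client_secret")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Dump instead of sending: shows Authorization: Basic Y2xpZW50X2lk...
	// and the raw body grant_type=client_credentials&scope=all-apis.
	dump, err := httputil.DumpRequestOut(req, true)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(dump))
}
```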


@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}


@ -0,0 +1,8 @@
# Unset the token which is configured by default
# in acceptance tests
export DATABRICKS_TOKEN=""
export DATABRICKS_CLIENT_ID=client_id
export DATABRICKS_CLIENT_SECRET=client_secret
$CLI current-user me


@ -0,0 +1,5 @@
# "client_id:client_secret" in base64 is Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=, expect to
# see this in Authorization header
[[Repls]]
Old = "Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ="
New = "[ENCODED_AUTH]"


@ -0,0 +1,12 @@
{
"headers": {
"Authorization": [
"Bearer dapi1234"
],
"User-Agent": [
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/pat"
]
},
"method": "GET",
"path": "/api/2.0/preview/scim/v2/Me"
}


@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}


@ -0,0 +1,3 @@
export DATABRICKS_TOKEN=dapi1234
$CLI current-user me


@ -0,0 +1,20 @@
LocalOnly = true
RecordRequests = true
IncludeRequestHeaders = ["Authorization", "User-Agent"]
[[Repls]]
Old = '(linux|darwin|windows)'
New = '[OS]'
[[Repls]]
Old = " upstream/[A-Za-z0-9.-]+"
New = ""
[[Repls]]
Old = " upstream-version/[A-Za-z0-9.-]+"
New = ""
[[Repls]]
Old = " cicd/[A-Za-z0-9.-]+"
New = ""


@ -14,11 +14,7 @@ import (
func StartCmdServer(t *testing.T) *testserver.Server {
server := testserver.New(t)
// {$} is a wildcard that only matches the end of the URL. We explicitly use
// /{$} to disambiguate it from the generic handler for '/' which is used to
// identify unhandled API endpoints in the test server.
server.Handle("/{$}", func(w *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/", func(_ *testserver.FakeWorkspace, r *http.Request) (any, int) {
q := r.URL.Query()
args := strings.Split(q.Get("args"), " ")


@ -0,0 +1,8 @@
{
"method": "GET",
"path": "/api/2.0/preview/scim/v2/Me"
}
{
"method": "GET",
"path": "/custom/endpoint"
}


@ -0,0 +1,15 @@
>>> curl -s [DATABRICKS_URL]/api/2.0/preview/scim/v2/Me
{
"id": "[USERID]",
"userName": "[USERNAME]"
}
>>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param
HTTP/1.1 201 Created
Content-Type: application/json
Date: (redacted)
Content-Length: (redacted)
custom
---
response


@ -0,0 +1,2 @@
trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me
trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param


@ -0,0 +1,18 @@
LocalOnly = true
RecordRequests = true
[[Server]]
Pattern = "GET /custom/endpoint"
Response.Body = '''custom
---
response
'''
Response.StatusCode = 201
[[Repls]]
Old = 'Date: .*'
New = 'Date: (redacted)'
[[Repls]]
Old = 'Content-Length: [0-9]*'
New = 'Content-Length: (redacted)'


@ -8,6 +8,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/gorilla/mux"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
@ -16,8 +17,13 @@ import (
"github.com/databricks/databricks-sdk-go/service/workspace"
)
var testUser = iam.User{
Id: "1000012345",
UserName: "tester@databricks.com",
}
func AddHandlers(server *testserver.Server) {
server.Handle("GET /api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return compute.ListPoliciesResponse{
Policies: []compute.Policy{
{
@ -32,7 +38,7 @@ func AddHandlers(server *testserver.Server) {
}, http.StatusOK
})
server.Handle("GET /api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return compute.ListInstancePools{
InstancePools: []compute.InstancePoolAndStats{
{
@ -43,7 +49,7 @@ func AddHandlers(server *testserver.Server) {
}, http.StatusOK
})
server.Handle("GET /api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return compute.ListClustersResponse{
Clusters: []compute.ClusterDetails{
{
@ -58,20 +64,17 @@ func AddHandlers(server *testserver.Server) {
}, http.StatusOK
})
server.Handle("GET /api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return iam.User{
Id: "1000012345",
UserName: "tester@databricks.com",
}, http.StatusOK
server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return testUser, http.StatusOK
})
server.Handle("GET /api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
path := r.URL.Query().Get("path")
return fakeWorkspace.WorkspaceGetStatus(path)
})
server.Handle("POST /api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("POST", "/api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
request := workspace.Mkdirs{}
decoder := json.NewDecoder(r.Body)
@ -83,13 +86,13 @@ func AddHandlers(server *testserver.Server) {
return fakeWorkspace.WorkspaceMkdirs(request)
})
server.Handle("GET /api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
path := r.URL.Query().Get("path")
return fakeWorkspace.WorkspaceExport(path)
})
server.Handle("POST /api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("POST", "/api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
path := r.URL.Query().Get("path")
recursiveStr := r.URL.Query().Get("recursive")
var recursive bool
@ -103,8 +106,9 @@ func AddHandlers(server *testserver.Server) {
return fakeWorkspace.WorkspaceDelete(path, recursive)
})
server.Handle("POST /api/2.0/workspace-files/import-file/{path}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
path := r.PathValue("path")
server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
vars := mux.Vars(r)
path := vars["path"]
body := new(bytes.Buffer)
_, err := body.ReadFrom(r.Body)
@ -115,14 +119,15 @@ func AddHandlers(server *testserver.Server) {
return fakeWorkspace.WorkspaceFilesImportFile(path, body.Bytes())
})
server.Handle("GET /api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return catalog.MetastoreAssignment{
DefaultCatalogName: "main",
}, http.StatusOK
})
server.Handle("GET /api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
objectId := r.PathValue("objectId")
server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
vars := mux.Vars(r)
objectId := vars["objectId"]
return workspace.WorkspaceObjectPermissions{
ObjectId: objectId,
@ -140,7 +145,7 @@ func AddHandlers(server *testserver.Server) {
}, http.StatusOK
})
server.Handle("POST /api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("POST", "/api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
request := jobs.CreateJob{}
decoder := json.NewDecoder(r.Body)
@ -152,15 +157,31 @@ func AddHandlers(server *testserver.Server) {
return fakeWorkspace.JobsCreate(request)
})
server.Handle("GET /api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
jobId := r.URL.Query().Get("job_id")
return fakeWorkspace.JobsGet(jobId)
})
server.Handle("GET /api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
server.Handle("GET", "/api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return fakeWorkspace.JobsList()
})
server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return map[string]string{
"authorization_endpoint": server.URL + "oidc/v1/authorize",
"token_endpoint": server.URL + "/oidc/v1/token",
}, http.StatusOK
})
server.Handle("POST", "/oidc/v1/token", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
return map[string]string{
"access_token": "oauth-token",
"expires_in": "3600",
"scope": "all-apis",
"token_type": "Bearer",
}, http.StatusOK
})
}
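
One subtlety in the handler changes above: net/http's `{path}` pattern (read via r.PathValue) matches a single path segment, while gorilla/mux's `{path:.*}` regex variable (read via mux.Vars) also matches across slashes. A standalone sketch with a hypothetical route:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"

	"github.com/gorilla/mux"
)

func main() {
	r := mux.NewRouter()

	// {path:.*} matches values containing slashes, e.g. nested file
	// paths, which a plain single-segment {path} variable would not.
	r.HandleFunc("/import-file/{path:.*}", func(w http.ResponseWriter, req *http.Request) {
		fmt.Fprint(w, mux.Vars(req)["path"])
	}).Methods("POST")

	srv := httptest.NewServer(r)
	defer srv.Close()

	resp, err := http.Post(srv.URL+"/import-file/a/b/c.txt", "text/plain", nil)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // a/b/c.txt
}
```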
func internalError(err error) (any, int) {

go.mod

@ -12,6 +12,7 @@ require (
github.com/databricks/databricks-sdk-go v0.57.0 // Apache 2.0
github.com/fatih/color v1.18.0 // MIT
github.com/google/uuid v1.6.0 // BSD-3-Clause
github.com/gorilla/mux v1.8.1 // BSD 3-Clause
github.com/hashicorp/go-version v1.7.0 // MPL 2.0
github.com/hashicorp/hc-install v0.9.1 // MPL 2.0
github.com/hashicorp/terraform-exec v0.22.0 // MPL 2.0

go.sum (generated)

@ -97,6 +97,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg=
github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=


@ -6,7 +6,9 @@ import (
"encoding/json"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"testing"
@ -887,3 +889,80 @@ func TestWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
})
}
}
func TestDbfsFilerForStreamingUploads(t *testing.T) {
ctx := context.Background()
f, _ := setupDbfsFiler(t)
// Set MaxDbfsPutFileSize to 1 to force streaming uploads
prevV := filer.MaxDbfsPutFileSize
filer.MaxDbfsPutFileSize = 1
t.Cleanup(func() {
filer.MaxDbfsPutFileSize = prevV
})
// Write a file to local disk.
tmpDir := t.TempDir()
testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")
fd, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
require.NoError(t, err)
defer fd.Close()
// Write a file with streaming upload
err = f.Write(ctx, "foo.txt", fd)
require.NoError(t, err)
// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")
// Overwrite the file with streaming upload, and fail
err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"))
require.ErrorIs(t, err, fs.ErrExist)
// Overwrite the file with streaming upload, and succeed
err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"), filer.OverwriteIfExists)
require.NoError(t, err)
// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
}
func TestDbfsFilerForPutUploads(t *testing.T) {
ctx := context.Background()
f, _ := setupDbfsFiler(t)
// Write a file to local disk.
tmpDir := t.TempDir()
testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")
testutil.WriteFile(t, filepath.Join(tmpDir, "bar.txt"), "barfoo")
fdFoo, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
require.NoError(t, err)
defer fdFoo.Close()
fdBar, err := os.Open(filepath.Join(tmpDir, "bar.txt"))
require.NoError(t, err)
defer fdBar.Close()
// Write a file with PUT upload
err = f.Write(ctx, "foo.txt", fdFoo)
require.NoError(t, err)
// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")
// Try to overwrite the file, and fail.
err = f.Write(ctx, "foo.txt", fdBar)
require.ErrorIs(t, err, fs.ErrExist)
// Reset the file descriptor.
_, err = fdBar.Seek(0, io.SeekStart)
require.NoError(t, err)
// Overwrite the file with OverwriteIfExists flag
err = f.Write(ctx, "foo.txt", fdBar, filer.OverwriteIfExists)
require.NoError(t, err)
// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
}


@ -1,11 +1,15 @@
package filer
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"mime/multipart"
"net/http"
"os"
"path"
"slices"
"sort"
@ -14,6 +18,7 @@ import (
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/service/files"
)
@ -63,33 +68,142 @@ func (info dbfsFileInfo) Sys() any {
return info.fi
}
// Interface to allow mocking of the Databricks API client.
type databricksClient interface {
Do(ctx context.Context, method, path string, headers map[string]string,
requestBody, responseBody any, visitors ...func(*http.Request) error) error
}
// DbfsClient implements the [Filer] interface for the DBFS backend.
type DbfsClient struct {
workspaceClient *databricks.WorkspaceClient
apiClient databricksClient
// File operations will be relative to this path.
root WorkspaceRootPath
}
func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
apiClient, err := client.New(w.Config)
if err != nil {
return nil, fmt.Errorf("failed to create API client: %w", err)
}
return &DbfsClient{
workspaceClient: w,
apiClient: apiClient,
root: NewWorkspaceRootPath(root),
}, nil
}
// The PUT API for DBFS requires setting the content length header beforehand in the HTTP
// request.
func contentLength(path, overwriteField string, file *os.File) (int64, error) {
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
err := writer.WriteField("path", path)
if err != nil {
return 0, fmt.Errorf("failed to write field path field in multipart form: %w", err)
}
err = writer.WriteField("overwrite", overwriteField)
if err != nil {
return 0, fmt.Errorf("failed to write field overwrite field in multipart form: %w", err)
}
_, err = writer.CreateFormFile("contents", "")
if err != nil {
return 0, fmt.Errorf("failed to write contents field in multipart form: %w", err)
}
err = writer.Close()
if err != nil {
return 0, fmt.Errorf("failed to close multipart form writer: %w", err)
}
stat, err := file.Stat()
if err != nil {
return 0, fmt.Errorf("failed to stat file %s: %w", path, err)
}
return int64(buf.Len()) + stat.Size(), nil
}
func contentLengthVisitor(path, overwriteField string, file *os.File) func(*http.Request) error {
return func(r *http.Request) error {
cl, err := contentLength(path, overwriteField, file)
if err != nil {
return fmt.Errorf("failed to calculate content length: %w", err)
}
r.ContentLength = cl
return nil
}
}
func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, file *os.File) error {
overwriteField := "False"
if overwrite {
overwriteField = "True"
}
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
go func() {
defer pw.Close()
err := writer.WriteField("path", path)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field path field in multipart form: %w", err))
return
}
err = writer.WriteField("overwrite", overwriteField)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field overwrite field in multipart form: %w", err))
return
}
contents, err := writer.CreateFormFile("contents", "")
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write contents field in multipart form: %w", err))
return
}
_, err = io.Copy(contents, file)
if err != nil {
pw.CloseWithError(fmt.Errorf("error while streaming file to dbfs: %w", err))
return
}
err = writer.Close()
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to close multipart form writer: %w", err))
return
}
}()
// The Go SDK does not directly support request bodies with Content-Type
// multipart/form-data for DBFS, so we use the Do method instead.
err := w.apiClient.Do(ctx,
http.MethodPost,
"/api/2.0/dbfs/put",
map[string]string{"Content-Type": writer.FormDataContentType()},
pr,
nil,
contentLengthVisitor(path, overwriteField, file))
var aerr *apierr.APIError
if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
return FileAlreadyExistsError{path}
}
return err
}
// MaxDbfsPutFileSize is the maximum size in bytes of a file that can be uploaded
// using the /dbfs/put API. If the file is larger than this limit, the streaming
// API (/dbfs/create and /dbfs/add-block) will be used instead.
var MaxDbfsPutFileSize int64 = 2 * 1024 * 1024 * 1024
func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
absPath, err := w.root.Join(name)
if err != nil {
return err
}
fileMode := files.FileModeWrite
if slices.Contains(mode, OverwriteIfExists) {
fileMode |= files.FileModeOverwrite
}
// Issue info call before write because it automatically creates parent directories.
//
// For discussion: we could decide this is actually convenient, remove the call below,
@ -114,7 +228,36 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
}
}
handle, err := w.workspaceClient.Dbfs.Open(ctx, absPath, fileMode)
localFile, ok := reader.(*os.File)
// If the source is not a local file, we'll always use the streaming API endpoint.
if !ok {
return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), reader)
}
stat, err := localFile.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
}
// If the source is a local file but is too large, we'll use the streaming API endpoint.
if stat.Size() > MaxDbfsPutFileSize {
return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
}
// Use the /dbfs/put API when the file is on the local filesystem
// and is small enough. This is the most common case when users use the
// `databricks fs cp` command.
return w.putFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
}
func (w *DbfsClient) streamFile(ctx context.Context, path string, overwrite bool, reader io.Reader) error {
fileMode := files.FileModeWrite
if overwrite {
fileMode |= files.FileModeOverwrite
}
handle, err := w.workspaceClient.Dbfs.Open(ctx, path, fileMode)
if err != nil {
var aerr *apierr.APIError
if !errors.As(err, &aerr) {
@ -124,7 +267,7 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
// This API returns a 400 if the file already exists.
if aerr.StatusCode == http.StatusBadRequest {
if aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
return FileAlreadyExistsError{absPath}
return FileAlreadyExistsError{path}
}
}
@ -136,7 +279,6 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
if err == nil {
err = cerr
}
return err
}
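
A worked illustration of the content-length trick in the contentLength/putFile code above: the multipart overhead is measured by writing an identical form with an empty file part into a buffer, and the real Content-Length is that overhead plus the file's size. A standalone sketch:

```go
package main

import (
	"bytes"
	"fmt"
	"mime/multipart"
)

// Overhead of a multipart form identical to the one streamed to
// /api/2.0/dbfs/put, but with zero-length file contents. The real
// request's Content-Length is this overhead plus the file's size.
func multipartOverhead(path, overwrite string) (int64, error) {
	buf := &bytes.Buffer{}
	w := multipart.NewWriter(buf)
	if err := w.WriteField("path", path); err != nil {
		return 0, err
	}
	if err := w.WriteField("overwrite", overwrite); err != nil {
		return 0, err
	}
	if _, err := w.CreateFormFile("contents", ""); err != nil {
		return 0, err
	}
	if err := w.Close(); err != nil {
		return 0, err
	}
	return int64(buf.Len()), nil
}

func main() {
	overhead, err := multipartOverhead("/a/b/foo.txt", "True")
	if err != nil {
		panic(err)
	}
	// For the 6-byte file "foobar" used in the integration test,
	// Content-Length would be overhead + 6.
	fmt.Println("multipart form overhead:", overhead, "bytes")
}
```

This is deterministic across two different multipart.Writer instances because Go's generated boundaries are random but fixed-length, which is what lets contentLengthVisitor compute the length with a fresh writer while putFile streams the body with another.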


@ -0,0 +1,155 @@
package filer
import (
"context"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/files"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockDbfsApiClient struct {
t testutil.TestingT
isCalled bool
}
func (m *mockDbfsApiClient) Do(ctx context.Context, method, path string,
headers map[string]string, request, response any,
visitors ...func(*http.Request) error,
) error {
m.isCalled = true
require.Equal(m.t, "POST", method)
require.Equal(m.t, "/api/2.0/dbfs/put", path)
require.Contains(m.t, headers["Content-Type"], "multipart/form-data; boundary=")
contents, err := io.ReadAll(request.(io.Reader))
require.NoError(m.t, err)
require.Contains(m.t, string(contents), "hello world")
return nil
}
func TestDbfsClientForSmallFiles(t *testing.T) {
// write file to local disk
tmp := t.TempDir()
localPath := filepath.Join(tmp, "hello.txt")
testutil.WriteFile(t, localPath, "hello world")
// setup DBFS client with mocks
m := mocks.NewMockWorkspaceClient(t)
mockApiClient := &mockDbfsApiClient{t: t}
dbfsClient := DbfsClient{
apiClient: mockApiClient,
workspaceClient: m.WorkspaceClient,
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
}
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
// write file to DBFS
fd, err := os.Open(localPath)
require.NoError(t, err)
defer fd.Close()
err = dbfsClient.Write(context.Background(), "hello.txt", fd)
require.NoError(t, err)
// verify mock API client is called
require.True(t, mockApiClient.isCalled)
}
type mockDbfsHandle struct {
builder strings.Builder
}
func (h *mockDbfsHandle) Read(data []byte) (n int, err error) { return 0, nil }
func (h *mockDbfsHandle) Close() error { return nil }
func (h *mockDbfsHandle) WriteTo(w io.Writer) (n int64, err error) { return 0, nil }
func (h *mockDbfsHandle) ReadFrom(r io.Reader) (n int64, err error) {
b, err := io.ReadAll(r)
if err != nil {
return 0, err
}
num, err := h.builder.Write(b)
return int64(num), err
}
func (h *mockDbfsHandle) Write(data []byte) (n int, err error) {
return h.builder.Write(data)
}
func TestDbfsClientForLargerFiles(t *testing.T) {
// write file to local disk
tmp := t.TempDir()
localPath := filepath.Join(tmp, "hello.txt")
testutil.WriteFile(t, localPath, "hello world")
// Modify the max file size to 1 byte to simulate
// a large file that needs to be uploaded in chunks.
oldV := MaxDbfsPutFileSize
MaxDbfsPutFileSize = 1
t.Cleanup(func() {
MaxDbfsPutFileSize = oldV
})
// setup DBFS client with mocks
m := mocks.NewMockWorkspaceClient(t)
mockApiClient := &mockDbfsApiClient{t: t}
dbfsClient := DbfsClient{
apiClient: mockApiClient,
workspaceClient: m.WorkspaceClient,
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
}
h := &mockDbfsHandle{}
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
// write file to DBFS
fd, err := os.Open(localPath)
require.NoError(t, err)
defer fd.Close()
err = dbfsClient.Write(context.Background(), "hello.txt", fd)
require.NoError(t, err)
// verify mock API client is NOT called
require.False(t, mockApiClient.isCalled)
// verify the file content was written to the mock handle
assert.Equal(t, "hello world", h.builder.String())
}
func TestDbfsClientForNonLocalFiles(t *testing.T) {
// setup DBFS client with mocks
m := mocks.NewMockWorkspaceClient(t)
mockApiClient := &mockDbfsApiClient{t: t}
dbfsClient := DbfsClient{
apiClient: mockApiClient,
workspaceClient: m.WorkspaceClient,
root: NewWorkspaceRootPath("dbfs:/a/b/c"),
}
h := &mockDbfsHandle{}
m.GetMockDbfsAPI().EXPECT().GetStatusByPath(mock.Anything, "dbfs:/a/b/c").Return(nil, nil)
m.GetMockDbfsAPI().EXPECT().Open(mock.Anything, "dbfs:/a/b/c/hello.txt", files.FileModeWrite).Return(h, nil)
// write file to DBFS
err := dbfsClient.Write(context.Background(), "hello.txt", strings.NewReader("hello world"))
require.NoError(t, err)
// verify mock API client is NOT called
require.False(t, mockApiClient.isCalled)
// verify the file content was written to the mock handle
assert.Equal(t, "hello world", h.builder.String())
}


@ -9,6 +9,8 @@ import (
"strings"
"sync"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/databricks/cli/internal/testutil"
@ -17,7 +19,7 @@ import (
type Server struct {
*httptest.Server
Mux *http.ServeMux
Router *mux.Router
t testutil.TestingT
@ -34,26 +36,25 @@ type Request struct {
Headers http.Header `json:"headers,omitempty"`
Method string `json:"method"`
Path string `json:"path"`
Body any `json:"body"`
Body any `json:"body,omitempty"`
RawBody string `json:"raw_body,omitempty"`
}
func New(t testutil.TestingT) *Server {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
router := mux.NewRouter()
server := httptest.NewServer(router)
t.Cleanup(server.Close)
s := &Server{
Server: server,
Mux: mux,
Router: router,
t: t,
mu: &sync.Mutex{},
fakeWorkspaces: map[string]*FakeWorkspace{},
}
// The server resolves conflicting handlers by using the one with higher
// specificity. This handler is the least specific, so it will be used as a
// fallback when no other handlers match.
s.Handle("/", func(fakeWorkspace *FakeWorkspace, r *http.Request) (any, int) {
// Set up the not found handler as fallback
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pattern := r.Method + " " + r.URL.Path
t.Errorf(`
@ -74,9 +75,22 @@ Response.StatusCode = <response status-code here>
`, pattern, pattern)
return apierr.APIError{
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotImplemented)
resp := apierr.APIError{
Message: "No stub found for pattern: " + pattern,
}, http.StatusNotImplemented
}
respBytes, err := json.Marshal(resp)
if err != nil {
t.Errorf("JSON encoding error: %s", err)
respBytes = []byte("{\"message\": \"JSON encoding error\"}")
}
if _, err := w.Write(respBytes); err != nil {
t.Errorf("Response write error: %s", err)
}
})
return s
@ -84,8 +98,8 @@ Response.StatusCode = <response status-code here>
type HandlerFunc func(fakeWorkspace *FakeWorkspace, req *http.Request) (resp any, statusCode int)
func (s *Server) Handle(pattern string, handler HandlerFunc) {
s.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
func (s *Server) Handle(method, path string, handler HandlerFunc) {
s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
// For simplicity we process requests sequentially. It's fast enough because
// we don't do any IO except reading and writing request/response bodies.
s.mu.Lock()
@ -119,13 +133,19 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) {
}
}
s.Requests = append(s.Requests, Request{
req := Request{
Headers: headers,
Method: r.Method,
Path: r.URL.Path,
Body: json.RawMessage(body),
})
}
if json.Valid(body) {
req.Body = json.RawMessage(body)
} else {
req.RawBody = string(body)
}
s.Requests = append(s.Requests, req)
}
w.Header().Set("Content-Type", "application/json")
@ -149,7 +169,7 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
}).Methods(method)
}
func getToken(r *http.Request) string {