diff --git a/NOTICE b/NOTICE index 4331a2a32..0b1d2da04 100644 --- a/NOTICE +++ b/NOTICE @@ -114,3 +114,7 @@ dario.cat/mergo Copyright (c) 2013 Dario Castañé. All rights reserved. Copyright (c) 2012 The Go Authors. All rights reserved. https://github.com/darccio/mergo/blob/master/LICENSE + +https://github.com/gorilla/mux +Copyright (c) 2023 The Gorilla Authors. All rights reserved. +https://github.com/gorilla/mux/blob/main/LICENSE diff --git a/acceptance/acceptance_test.go b/acceptance/acceptance_test.go index fce508498..117172f60 100644 --- a/acceptance/acceptance_test.go +++ b/acceptance/acceptance_test.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "path/filepath" + "regexp" "runtime" "slices" "sort" @@ -26,6 +27,7 @@ import ( "github.com/databricks/cli/libs/testdiff" "github.com/databricks/cli/libs/testserver" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/iam" "github.com/stretchr/testify/require" ) @@ -72,7 +74,8 @@ func TestInprocessMode(t *testing.T) { if InprocessMode { t.Skip("Already tested by TestAccept") } - require.Equal(t, 1, testAccept(t, true, "selftest")) + require.Equal(t, 1, testAccept(t, true, "selftest/basic")) + require.Equal(t, 1, testAccept(t, true, "selftest/server")) } func testAccept(t *testing.T, InprocessMode bool, singleTest string) int { @@ -118,14 +121,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int { uvCache := getUVDefaultCacheDir(t) t.Setenv("UV_CACHE_DIR", uvCache) - ctx := context.Background() cloudEnv := os.Getenv("CLOUD_ENV") if cloudEnv == "" { defaultServer := testserver.New(t) AddHandlers(defaultServer) - // Redirect API access to local server: - t.Setenv("DATABRICKS_HOST", defaultServer.URL) + t.Setenv("DATABRICKS_DEFAULT_HOST", defaultServer.URL) homeDir := t.TempDir() // Do not read user's ~/.databrickscfg @@ -148,27 +149,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int { // do it last so that full paths match first: repls.SetPath(buildDir, "[BUILD_DIR]") - var config databricks.Config - if cloudEnv == "" { - // use fake token for local tests - config = databricks.Config{Token: "dbapi1234"} - } else { - // non-local tests rely on environment variables - config = databricks.Config{} - } - workspaceClient, err := databricks.NewWorkspaceClient(&config) - require.NoError(t, err) - - user, err := workspaceClient.CurrentUser.Me(ctx) - require.NoError(t, err) - require.NotNil(t, user) - testdiff.PrepareReplacementsUser(t, &repls, *user) - testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient) - testdiff.PrepareReplacementsUUID(t, &repls) testdiff.PrepareReplacementsDevVersion(t, &repls) testdiff.PrepareReplacementSdkVersion(t, &repls) testdiff.PrepareReplacementsGoVersion(t, &repls) + repls.Repls = append(repls.Repls, testdiff.Replacement{Old: regexp.MustCompile("dbapi[0-9a-f]+"), New: "[DATABRICKS_TOKEN]"}) + testDirs := getTests(t) require.NotEmpty(t, testDirs) @@ -239,7 +225,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont } repls.SetPathWithParents(tmpDir, "[TMPDIR]") - repls.Repls = append(repls.Repls, config.Repls...) scriptContents := readMergedScriptContents(t, dir) testutil.WriteFile(t, filepath.Join(tmpDir, EntryPointScript), scriptContents) @@ -253,38 +238,79 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont cmd := exec.Command(args[0], args[1:]...) cmd.Env = os.Environ() + var workspaceClient *databricks.WorkspaceClient + var user iam.User + // Start a new server with a custom configuration if the acceptance test // specifies a custom server stubs. var server *testserver.Server - // Start a new server for this test if either: - // 1. A custom server spec is defined in the test configuration. - // 2. The test is configured to record requests and assert on them. We need - // a duplicate of the default server to record requests because the default - // server otherwise is a shared resource. - if cloudEnv == "" && (len(config.Server) > 0 || config.RecordRequests) { - server = testserver.New(t) - server.RecordRequests = config.RecordRequests - server.IncludeRequestHeaders = config.IncludeRequestHeaders + if cloudEnv == "" { + // Start a new server for this test if either: + // 1. A custom server spec is defined in the test configuration. + // 2. The test is configured to record requests and assert on them. We need + // a duplicate of the default server to record requests because the default + // server otherwise is a shared resource. - // If no custom server stubs are defined, add the default handlers. - if len(config.Server) == 0 { + databricksLocalHost := os.Getenv("DATABRICKS_DEFAULT_HOST") + + if len(config.Server) > 0 || config.RecordRequests { + server = testserver.New(t) + server.RecordRequests = config.RecordRequests + server.IncludeRequestHeaders = config.IncludeRequestHeaders + + for _, stub := range config.Server { + require.NotEmpty(t, stub.Pattern) + items := strings.Split(stub.Pattern, " ") + require.Len(t, items, 2) + server.Handle(items[0], items[1], func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) { + statusCode := http.StatusOK + if stub.Response.StatusCode != 0 { + statusCode = stub.Response.StatusCode + } + return stub.Response.Body, statusCode + }) + } + + // The earliest handlers take precedence, add default handlers last AddHandlers(server) + databricksLocalHost = server.URL } - for _, stub := range config.Server { - require.NotEmpty(t, stub.Pattern) - server.Handle(stub.Pattern, func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) { - statusCode := http.StatusOK - if stub.Response.StatusCode != 0 { - statusCode = stub.Response.StatusCode - } - return stub.Response.Body, statusCode - }) + // Each local test should use a new token that will result into a new fake workspace, + // so that test don't interfere with each other. + tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "") + config := databricks.Config{ + Host: databricksLocalHost, + Token: "dbapi" + tokenSuffix, } - cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+server.URL) + workspaceClient, err = databricks.NewWorkspaceClient(&config) + require.NoError(t, err) + + cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+config.Host) + cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+config.Token) + + // For the purposes of replacements, use testUser. + // Note, users might have overriden /api/2.0/preview/scim/v2/Me but that should not affect the replacement: + user = testUser + } else { + // Use whatever authentication mechanism is configured by the test runner. + workspaceClient, err = databricks.NewWorkspaceClient(&databricks.Config{}) + require.NoError(t, err) + pUser, err := workspaceClient.CurrentUser.Me(context.Background()) + require.NoError(t, err, "Failed to get current user") + user = *pUser } + testdiff.PrepareReplacementsUser(t, &repls, user) + testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient) + + // Must be added PrepareReplacementsUser, otherwise conflicts with [USERNAME] + testdiff.PrepareReplacementsUUID(t, &repls) + + // User replacements come last: + repls.Repls = append(repls.Repls, config.Repls...) + if coverDir != "" { // Creating individual coverage directory for each test, because writing to the same one // results in sporadic failures like this one (only if tests are running in parallel): @@ -295,15 +321,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont cmd.Env = append(cmd.Env, "GOCOVERDIR="+coverDir) } - // Each local test should use a new token that will result into a new fake workspace, - // so that test don't interfere with each other. - if cloudEnv == "" { - tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "") - token := "dbapi" + tokenSuffix - cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+token) - repls.Set(token, "[DATABRICKS_TOKEN]") - } - // Write combined output to a file out, err := os.Create(filepath.Join(tmpDir, "output.txt")) require.NoError(t, err) @@ -320,7 +337,7 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont for _, req := range server.Requests { reqJson, err := json.MarshalIndent(req, "", " ") - require.NoError(t, err) + require.NoErrorf(t, err, "Failed to indent: %#v", req) reqJsonWithRepls := repls.Replace(string(reqJson)) _, err = f.WriteString(reqJsonWithRepls + "\n") diff --git a/acceptance/auth/bundle_and_profile/output.txt b/acceptance/auth/bundle_and_profile/output.txt index 022b3148d..8d2584622 100644 --- a/acceptance/auth/bundle_and_profile/output.txt +++ b/acceptance/auth/bundle_and_profile/output.txt @@ -13,13 +13,13 @@ === Inside the bundle, profile flag not matching bundle host. Badness: should use profile from flag instead and not fail >>> errcode [CLI] current-user me -p profile_name -Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL] +Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET] Exit code: 1 === Inside the bundle, target and not matching profile >>> errcode [CLI] current-user me -t dev -p profile_name -Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL] +Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET] Exit code: 1 diff --git a/acceptance/auth/bundle_and_profile/test.toml b/acceptance/auth/bundle_and_profile/test.toml index b20190ca5..1a611ed95 100644 --- a/acceptance/auth/bundle_and_profile/test.toml +++ b/acceptance/auth/bundle_and_profile/test.toml @@ -5,4 +5,8 @@ Badness = "When -p flag is used inside the bundle folder for any CLI commands, C # This is a workaround to replace DATABRICKS_URL with DATABRICKS_HOST [[Repls]] Old='DATABRICKS_HOST' -New='DATABRICKS_URL' +New='DATABRICKS_TARGET' + +[[Repls]] +Old='DATABRICKS_URL' +New='DATABRICKS_TARGET' diff --git a/acceptance/cmd_server_test.go b/acceptance/cmd_server_test.go index c8a52f4cd..d3db06003 100644 --- a/acceptance/cmd_server_test.go +++ b/acceptance/cmd_server_test.go @@ -14,11 +14,7 @@ import ( func StartCmdServer(t *testing.T) *testserver.Server { server := testserver.New(t) - - // {$} is a wildcard that only matches the end of the URL. We explicitly use - // /{$} to disambiguate it from the generic handler for '/' which is used to - // identify unhandled API endpoints in the test server. - server.Handle("/{$}", func(w *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/", func(_ *testserver.FakeWorkspace, r *http.Request) (any, int) { q := r.URL.Query() args := strings.Split(q.Get("args"), " ") diff --git a/acceptance/selftest/out.hello.txt b/acceptance/selftest/basic/out.hello.txt similarity index 100% rename from acceptance/selftest/out.hello.txt rename to acceptance/selftest/basic/out.hello.txt diff --git a/acceptance/selftest/output.txt b/acceptance/selftest/basic/output.txt similarity index 100% rename from acceptance/selftest/output.txt rename to acceptance/selftest/basic/output.txt diff --git a/acceptance/selftest/script b/acceptance/selftest/basic/script similarity index 100% rename from acceptance/selftest/script rename to acceptance/selftest/basic/script diff --git a/acceptance/selftest/test.toml b/acceptance/selftest/basic/test.toml similarity index 100% rename from acceptance/selftest/test.toml rename to acceptance/selftest/basic/test.toml diff --git a/acceptance/selftest/server/out.requests.txt b/acceptance/selftest/server/out.requests.txt new file mode 100644 index 000000000..2cb8708ac --- /dev/null +++ b/acceptance/selftest/server/out.requests.txt @@ -0,0 +1,8 @@ +{ + "method": "GET", + "path": "/api/2.0/preview/scim/v2/Me" +} +{ + "method": "GET", + "path": "/custom/endpoint" +} diff --git a/acceptance/selftest/server/output.txt b/acceptance/selftest/server/output.txt new file mode 100644 index 000000000..f9e51caa9 --- /dev/null +++ b/acceptance/selftest/server/output.txt @@ -0,0 +1,15 @@ + +>>> curl -s [DATABRICKS_URL]/api/2.0/preview/scim/v2/Me +{ + "id": "[USERID]", + "userName": "[USERNAME]" +} +>>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param +HTTP/1.1 201 Created +Content-Type: application/json +Date: (redacted) +Content-Length: (redacted) + +custom +--- +response diff --git a/acceptance/selftest/server/script b/acceptance/selftest/server/script new file mode 100644 index 000000000..53e2c4b8a --- /dev/null +++ b/acceptance/selftest/server/script @@ -0,0 +1,2 @@ +trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me +trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param diff --git a/acceptance/selftest/server/test.toml b/acceptance/selftest/server/test.toml new file mode 100644 index 000000000..2531fb910 --- /dev/null +++ b/acceptance/selftest/server/test.toml @@ -0,0 +1,18 @@ +LocalOnly = true +RecordRequests = true + +[[Server]] +Pattern = "GET /custom/endpoint" +Response.Body = '''custom +--- +response +''' +Response.StatusCode = 201 + +[[Repls]] +Old = 'Date: .*' +New = 'Date: (redacted)' + +[[Repls]] +Old = 'Content-Length: [0-9]*' +New = 'Content-Length: (redacted)' diff --git a/acceptance/server_test.go b/acceptance/server_test.go index d21ab66e8..11d03c30b 100644 --- a/acceptance/server_test.go +++ b/acceptance/server_test.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/gorilla/mux" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" @@ -16,8 +17,13 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) +var testUser = iam.User{ + Id: "1000012345", + UserName: "tester@databricks.com", +} + func AddHandlers(server *testserver.Server) { - server.Handle("GET /api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { return compute.ListPoliciesResponse{ Policies: []compute.Policy{ { @@ -32,7 +38,7 @@ func AddHandlers(server *testserver.Server) { }, http.StatusOK }) - server.Handle("GET /api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { return compute.ListInstancePools{ InstancePools: []compute.InstancePoolAndStats{ { @@ -43,7 +49,7 @@ func AddHandlers(server *testserver.Server) { }, http.StatusOK }) - server.Handle("GET /api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { return compute.ListClustersResponse{ Clusters: []compute.ClusterDetails{ { @@ -58,20 +64,17 @@ func AddHandlers(server *testserver.Server) { }, http.StatusOK }) - server.Handle("GET /api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { - return iam.User{ - Id: "1000012345", - UserName: "tester@databricks.com", - }, http.StatusOK + server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + return testUser, http.StatusOK }) - server.Handle("GET /api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { path := r.URL.Query().Get("path") return fakeWorkspace.WorkspaceGetStatus(path) }) - server.Handle("POST /api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("POST", "/api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { request := workspace.Mkdirs{} decoder := json.NewDecoder(r.Body) @@ -83,13 +86,13 @@ func AddHandlers(server *testserver.Server) { return fakeWorkspace.WorkspaceMkdirs(request) }) - server.Handle("GET /api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { path := r.URL.Query().Get("path") return fakeWorkspace.WorkspaceExport(path) }) - server.Handle("POST /api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("POST", "/api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { path := r.URL.Query().Get("path") recursiveStr := r.URL.Query().Get("recursive") var recursive bool @@ -103,8 +106,9 @@ func AddHandlers(server *testserver.Server) { return fakeWorkspace.WorkspaceDelete(path, recursive) }) - server.Handle("POST /api/2.0/workspace-files/import-file/{path}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { - path := r.PathValue("path") + server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + vars := mux.Vars(r) + path := vars["path"] body := new(bytes.Buffer) _, err := body.ReadFrom(r.Body) @@ -115,14 +119,15 @@ func AddHandlers(server *testserver.Server) { return fakeWorkspace.WorkspaceFilesImportFile(path, body.Bytes()) }) - server.Handle("GET /api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { return catalog.MetastoreAssignment{ DefaultCatalogName: "main", }, http.StatusOK }) - server.Handle("GET /api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { - objectId := r.PathValue("objectId") + server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + vars := mux.Vars(r) + objectId := vars["objectId"] return workspace.WorkspaceObjectPermissions{ ObjectId: objectId, @@ -140,7 +145,7 @@ func AddHandlers(server *testserver.Server) { }, http.StatusOK }) - server.Handle("POST /api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("POST", "/api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { request := jobs.CreateJob{} decoder := json.NewDecoder(r.Body) @@ -152,13 +157,13 @@ func AddHandlers(server *testserver.Server) { return fakeWorkspace.JobsCreate(request) }) - server.Handle("GET /api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { jobId := r.URL.Query().Get("job_id") return fakeWorkspace.JobsGet(jobId) }) - server.Handle("GET /api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { + server.Handle("GET", "/api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { return fakeWorkspace.JobsList() }) } diff --git a/go.mod b/go.mod index c8b209edd..2e2505361 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/databricks/databricks-sdk-go v0.57.0 // Apache 2.0 github.com/fatih/color v1.18.0 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause + github.com/gorilla/mux v1.8.1 // BSD 3-Clause github.com/hashicorp/go-version v1.7.0 // MPL 2.0 github.com/hashicorp/hc-install v0.9.1 // MPL 2.0 github.com/hashicorp/terraform-exec v0.22.0 // MPL 2.0 diff --git a/go.sum b/go.sum index 0369fc2d9..fbf942148 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= diff --git a/libs/testserver/server.go b/libs/testserver/server.go index 577ef082c..cf4d5aca2 100644 --- a/libs/testserver/server.go +++ b/libs/testserver/server.go @@ -9,6 +9,8 @@ import ( "strings" "sync" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" "github.com/databricks/cli/internal/testutil" @@ -17,7 +19,7 @@ import ( type Server struct { *httptest.Server - Mux *http.ServeMux + Router *mux.Router t testutil.TestingT @@ -39,22 +41,20 @@ type Request struct { } func New(t testutil.TestingT) *Server { - mux := http.NewServeMux() - server := httptest.NewServer(mux) + router := mux.NewRouter() + server := httptest.NewServer(router) t.Cleanup(server.Close) s := &Server{ Server: server, - Mux: mux, + Router: router, t: t, mu: &sync.Mutex{}, fakeWorkspaces: map[string]*FakeWorkspace{}, } - // The server resolves conflicting handlers by using the one with higher - // specificity. This handler is the least specific, so it will be used as a - // fallback when no other handlers match. - s.Handle("/", func(fakeWorkspace *FakeWorkspace, r *http.Request) (any, int) { + // Set up the not found handler as fallback + router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pattern := r.Method + " " + r.URL.Path t.Errorf(` @@ -75,9 +75,22 @@ Response.StatusCode = `, pattern, pattern) - return apierr.APIError{ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotImplemented) + + resp := apierr.APIError{ Message: "No stub found for pattern: " + pattern, - }, http.StatusNotImplemented + } + + respBytes, err := json.Marshal(resp) + if err != nil { + t.Errorf("JSON encoding error: %s", err) + respBytes = []byte("{\"message\": \"JSON encoding error\"}") + } + + if _, err := w.Write(respBytes); err != nil { + t.Errorf("Response write error: %s", err) + } }) return s @@ -85,8 +98,8 @@ Response.StatusCode = type HandlerFunc func(fakeWorkspace *FakeWorkspace, req *http.Request) (resp any, statusCode int) -func (s *Server) Handle(pattern string, handler HandlerFunc) { - s.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { +func (s *Server) Handle(method, path string, handler HandlerFunc) { + s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { // For simplicity we process requests sequentially. It's fast enough because // we don't do any IO except reading and writing request/response bodies. s.mu.Lock() @@ -156,7 +169,7 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - }) + }).Methods(method) } func getToken(r *http.Request) string {