acc: Simplify writing handlers; support customing headers; fix read body bug

Handlers now receive testserver.Request and return any which could be
- nil (returns 404)
- string / []byte (returns it as is but sets content-type to json or test depending on content)
- object (encodes it as json and sets content-type to json)
- testserver.Response (full control over status, headers)

The config is now using the same testserver.Response struct as handlers, so the same logic applies there.

It is now possible to specify headers in test.toml.

This also fixes a bug with RecordRequest reading the body, not leaving it for the actual handler.
This commit is contained in:
Denis Bilenko 2025-02-11 17:13:42 +01:00
parent bfde3585b9
commit abea174a6a
12 changed files with 316 additions and 189 deletions

View File

@ -7,7 +7,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@ -267,12 +266,8 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
require.NotEmpty(t, stub.Pattern) require.NotEmpty(t, stub.Pattern)
items := strings.Split(stub.Pattern, " ") items := strings.Split(stub.Pattern, " ")
require.Len(t, items, 2) require.Len(t, items, 2)
server.Handle(items[0], items[1], func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) { server.Handle(items[0], items[1], func(req testserver.Request) any {
statusCode := http.StatusOK return stub.Response
if stub.Response.StatusCode != 0 {
statusCode = stub.Response.StatusCode
}
return stub.Response.Body, statusCode
}) })
} }

View File

@ -9,7 +9,7 @@
10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:folder_permissions 10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:folder_permissions
10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:validate_sync_patterns 10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:validate_sync_patterns
10:07:59 Debug: Path /Workspace/Users/[USERNAME]/.bundle/debug/default/files has type directory (ID: 0) pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync 10:07:59 Debug: Path /Workspace/Users/[USERNAME]/.bundle/debug/default/files has type directory (ID: 0) pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync
10:07:59 Debug: non-retriable error: pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true 10:07:59 Debug: non-retriable error: Not Found pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< {} pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < HTTP/0.0 000 Not Found (Error: Not Found) pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< {} pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < HTTP/0.0 000 OK pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< } pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < } pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true

View File

@ -78,12 +78,10 @@
10:07:59 Debug: No script defined for postinit, skipping pid=12345 mutator=initialize mutator=seq mutator=scripts.postinit 10:07:59 Debug: No script defined for postinit, skipping pid=12345 mutator=initialize mutator=seq mutator=scripts.postinit
10:07:59 Debug: Apply pid=12345 mutator=validate 10:07:59 Debug: Apply pid=12345 mutator=validate
10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files 10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files
< HTTP/1.1 404 Not Found
10:07:59 Debug: POST /api/2.0/workspace/mkdirs 10:07:59 Debug: POST /api/2.0/workspace/mkdirs
> { > {
> "path": "/Workspace/Users/[USERNAME]/.bundle/debug/default/files" > "path": "/Workspace/Users/[USERNAME]/.bundle/debug/default/files"
> } > }
< HTTP/1.1 200 OK
10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files 10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files
< HTTP/1.1 200 OK < HTTP/1.1 200 OK
< { < {

View File

@ -1,8 +1,8 @@
package acceptance_test package acceptance_test
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -14,7 +14,7 @@ import (
func StartCmdServer(t *testing.T) *testserver.Server { func StartCmdServer(t *testing.T) *testserver.Server {
server := testserver.New(t) server := testserver.New(t)
server.Handle("GET", "/", func(_ *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/", func(r testserver.Request) any {
q := r.URL.Query() q := r.URL.Query()
args := strings.Split(q.Get("args"), " ") args := strings.Split(q.Get("args"), " ")
@ -27,7 +27,7 @@ func StartCmdServer(t *testing.T) *testserver.Server {
defer Chdir(t, q.Get("cwd"))() defer Chdir(t, q.Get("cwd"))()
c := testcli.NewRunner(t, r.Context(), args...) c := testcli.NewRunner(t, context.Background(), args...)
c.Verbose = false c.Verbose = false
stdout, stderr, err := c.Run() stdout, stderr, err := c.Run()
result := map[string]any{ result := map[string]any{
@ -39,7 +39,7 @@ func StartCmdServer(t *testing.T) *testserver.Server {
exitcode = 1 exitcode = 1
} }
result["exitcode"] = exitcode result["exitcode"] = exitcode
return result, http.StatusOK return result
}) })
return server return server
} }

View File

@ -10,6 +10,7 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/databricks/cli/libs/testdiff" "github.com/databricks/cli/libs/testdiff"
"github.com/databricks/cli/libs/testserver"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -56,10 +57,7 @@ type ServerStub struct {
Pattern string Pattern string
// The response body to return. // The response body to return.
Response struct { Response testserver.Response
Body string
StatusCode int
}
} }
// FindConfigs finds all the config relevant for this test, // FindConfigs finds all the config relevant for this test,

View File

@ -6,3 +6,7 @@
"method": "GET", "method": "GET",
"path": "/custom/endpoint" "path": "/custom/endpoint"
} }
{
"method": "GET",
"path": "/api/2.0/workspace/get-status"
}

View File

@ -6,10 +6,16 @@
} }
>>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param >>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param
HTTP/1.1 201 Created HTTP/1.1 201 Created
Content-Type: application/json X-Custom-Header: hello
Date: (redacted) Date: (redacted)
Content-Length: (redacted) Content-Length: (redacted)
Content-Type: text/plain; charset=utf-8
custom custom
--- ---
response response
>>> errcode [CLI] workspace get-status /a/b/c
Error: Not Found
Exit code: 1

View File

@ -1,2 +1,4 @@
trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me
trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param
trace errcode $CLI workspace get-status /a/b/c

View File

@ -12,6 +12,8 @@ Response.Body = '''custom
response response
''' '''
Response.StatusCode = 201 Response.StatusCode = 201
[Server.Response.Headers]
"X-Custom-Header" = ["hello"]
[[Repls]] [[Repls]]
Old = 'Date: .*' Old = 'Date: .*'

View File

@ -1,14 +1,12 @@
package acceptance_test package acceptance_test
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/iam"
"github.com/gorilla/mux"
"github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
@ -23,7 +21,7 @@ var testUser = iam.User{
} }
func AddHandlers(server *testserver.Server) { func AddHandlers(server *testserver.Server) {
server.Handle("GET", "/api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/policies/clusters/list", func(req testserver.Request) any {
return compute.ListPoliciesResponse{ return compute.ListPoliciesResponse{
Policies: []compute.Policy{ Policies: []compute.Policy{
{ {
@ -35,10 +33,10 @@ func AddHandlers(server *testserver.Server) {
Name: "some-test-cluster-policy", Name: "some-test-cluster-policy",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/instance-pools/list", func(req testserver.Request) any {
return compute.ListInstancePools{ return compute.ListInstancePools{
InstancePools: []compute.InstancePoolAndStats{ InstancePools: []compute.InstancePoolAndStats{
{ {
@ -46,10 +44,10 @@ func AddHandlers(server *testserver.Server) {
InstancePoolId: "1234", InstancePoolId: "1234",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/clusters/list", func(req testserver.Request) any {
return compute.ListClustersResponse{ return compute.ListClustersResponse{
Clusters: []compute.ClusterDetails{ Clusters: []compute.ClusterDetails{
{ {
@ -61,74 +59,59 @@ func AddHandlers(server *testserver.Server) {
ClusterId: "9876", ClusterId: "9876",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(req testserver.Request) any {
return testUser, http.StatusOK return testUser
}) })
server.Handle("GET", "/api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/workspace/get-status", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
return req.Workspace.WorkspaceGetStatus(path)
return fakeWorkspace.WorkspaceGetStatus(path)
}) })
server.Handle("POST", "/api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace/mkdirs", func(req testserver.Request) any {
request := workspace.Mkdirs{} var request workspace.Mkdirs
decoder := json.NewDecoder(r.Body) if err := json.Unmarshal(req.Body, &request); err != nil {
return testserver.Response{
err := decoder.Decode(&request) Body: fmt.Sprintf("internal error: %s", err),
if err != nil { StatusCode: http.StatusInternalServerError,
return internalError(err) }
} }
return fakeWorkspace.WorkspaceMkdirs(request) req.Workspace.WorkspaceMkdirs(request)
return ""
}) })
server.Handle("GET", "/api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/workspace/export", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
return req.Workspace.WorkspaceExport(path)
return fakeWorkspace.WorkspaceExport(path)
}) })
server.Handle("POST", "/api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace/delete", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
recursiveStr := r.URL.Query().Get("recursive") recursive := req.URL.Query().Get("recursive") == "true"
var recursive bool req.Workspace.WorkspaceDelete(path, recursive)
return ""
if recursiveStr == "true" {
recursive = true
} else {
recursive = false
}
return fakeWorkspace.WorkspaceDelete(path, recursive)
}) })
server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(req testserver.Request) any {
vars := mux.Vars(r) path := req.Vars["path"]
path := vars["path"] req.Workspace.WorkspaceFilesImportFile(path, req.Body)
return ""
body := new(bytes.Buffer)
_, err := body.ReadFrom(r.Body)
if err != nil {
return internalError(err)
}
return fakeWorkspace.WorkspaceFilesImportFile(path, body.Bytes())
}) })
server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(req testserver.Request) any {
return catalog.MetastoreAssignment{ return catalog.MetastoreAssignment{
DefaultCatalogName: "main", DefaultCatalogName: "main",
}, http.StatusOK MetastoreId: "45f8dcfe-1914-47bc-b00e-1a5b9fea2cfc",
WorkspaceId: 100200300400,
}
}) })
server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(req testserver.Request) any {
vars := mux.Vars(r) objectId := req.Vars["objectId"]
objectId := vars["objectId"]
return workspace.WorkspaceObjectPermissions{ return workspace.WorkspaceObjectPermissions{
ObjectId: objectId, ObjectId: objectId,
ObjectType: "DIRECTORY", ObjectType: "DIRECTORY",
@ -142,48 +125,44 @@ func AddHandlers(server *testserver.Server) {
}, },
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("POST", "/api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.1/jobs/create", func(req testserver.Request) any {
request := jobs.CreateJob{} var request jobs.CreateJob
decoder := json.NewDecoder(r.Body) if err := json.Unmarshal(req.Body, &request); err != nil {
return testserver.Response{
err := decoder.Decode(&request) Body: fmt.Sprintf("internal error: %s", err),
if err != nil { StatusCode: 500,
return internalError(err) }
} }
return fakeWorkspace.JobsCreate(request) req.Workspace.JobsCreate(request)
return ""
}) })
server.Handle("GET", "/api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/jobs/get", func(req testserver.Request) any {
jobId := r.URL.Query().Get("job_id") jobId := req.URL.Query().Get("job_id")
return req.Workspace.JobsGet(jobId)
return fakeWorkspace.JobsGet(jobId)
}) })
server.Handle("GET", "/api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/jobs/list", func(req testserver.Request) any {
return fakeWorkspace.JobsList() return req.Workspace.JobsList()
}) })
server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(_ testserver.Request) any {
return map[string]string{ return map[string]string{
"authorization_endpoint": server.URL + "oidc/v1/authorize", "authorization_endpoint": server.URL + "oidc/v1/authorize",
"token_endpoint": server.URL + "/oidc/v1/token", "token_endpoint": server.URL + "/oidc/v1/token",
}, http.StatusOK }
}) })
server.Handle("POST", "/oidc/v1/token", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/oidc/v1/token", func(_ testserver.Request) any {
return map[string]string{ return map[string]string{
"access_token": "oauth-token", "access_token": "oauth-token",
"expires_in": "3600", "expires_in": "3600",
"scope": "all-apis", "scope": "all-apis",
"token_type": "Bearer", "token_type": "Bearer",
}, http.StatusOK }
}) })
} }
func internalError(err error) (any, int) {
return fmt.Errorf("internal error: %w", err), http.StatusInternalServerError
}

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -33,40 +32,32 @@ func NewFakeWorkspace() *FakeWorkspace {
} }
} }
func (s *FakeWorkspace) WorkspaceGetStatus(path string) (workspace.ObjectInfo, int) { func (s *FakeWorkspace) WorkspaceGetStatus(path string) *workspace.ObjectInfo {
if s.directories[path] { if s.directories[path] {
return workspace.ObjectInfo{ return &workspace.ObjectInfo{
ObjectType: "DIRECTORY", ObjectType: "DIRECTORY",
Path: path, Path: path,
}, http.StatusOK }
} else if _, ok := s.files[path]; ok { } else if _, ok := s.files[path]; ok {
return workspace.ObjectInfo{ return &workspace.ObjectInfo{
ObjectType: "FILE", ObjectType: "FILE",
Path: path, Path: path,
Language: "SCALA", Language: "SCALA",
}, http.StatusOK }
} else { } else {
return workspace.ObjectInfo{}, http.StatusNotFound return nil
} }
} }
func (s *FakeWorkspace) WorkspaceMkdirs(request workspace.Mkdirs) (string, int) { func (s *FakeWorkspace) WorkspaceMkdirs(request workspace.Mkdirs) {
s.directories[request.Path] = true s.directories[request.Path] = true
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceExport(path string) ([]byte, int) { func (s *FakeWorkspace) WorkspaceExport(path string) []byte {
file := s.files[path] return s.files[path]
if file == nil {
return nil, http.StatusNotFound
}
return file, http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) (string, int) { func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) {
if !recursive { if !recursive {
s.files[path] = nil s.files[path] = nil
} else { } else {
@ -76,28 +67,26 @@ func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) (string, in
} }
} }
} }
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceFilesImportFile(path string, body []byte) (any, int) { func (s *FakeWorkspace) WorkspaceFilesImportFile(path string, body []byte) {
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
path = "/" + path path = "/" + path
} }
s.files[path] = body s.files[path] = body
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) (any, int) { func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) Response {
jobId := s.nextJobId jobId := s.nextJobId
s.nextJobId++ s.nextJobId++
jobSettings := jobs.JobSettings{} jobSettings := jobs.JobSettings{}
err := jsonConvert(request, &jobSettings) err := jsonConvert(request, &jobSettings)
if err != nil { if err != nil {
return internalError(err) return Response{
StatusCode: 400,
Body: fmt.Sprintf("Cannot convert request to jobSettings: %s", err),
}
} }
s.jobs[jobId] = jobs.Job{ s.jobs[jobId] = jobs.Job{
@ -105,32 +94,44 @@ func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) (any, int) {
Settings: &jobSettings, Settings: &jobSettings,
} }
return jobs.CreateResponse{JobId: jobId}, http.StatusOK return Response{
Body: jobs.CreateResponse{JobId: jobId},
}
} }
func (s *FakeWorkspace) JobsGet(jobId string) (any, int) { func (s *FakeWorkspace) JobsGet(jobId string) Response {
id := jobId id := jobId
jobIdInt, err := strconv.ParseInt(id, 10, 64) jobIdInt, err := strconv.ParseInt(id, 10, 64)
if err != nil { if err != nil {
return internalError(fmt.Errorf("failed to parse job id: %s", err)) return Response{
StatusCode: 400,
Body: fmt.Sprintf("Failed to parse job id: %s: %v", err, id),
}
} }
job, ok := s.jobs[jobIdInt] job, ok := s.jobs[jobIdInt]
if !ok { if !ok {
return jobs.Job{}, http.StatusNotFound return Response{
StatusCode: 404,
}
} }
return job, http.StatusOK return Response{
Body: job,
}
} }
func (s *FakeWorkspace) JobsList() (any, int) { func (s *FakeWorkspace) JobsList() Response {
list := make([]jobs.BaseJob, 0, len(s.jobs)) list := make([]jobs.BaseJob, 0, len(s.jobs))
for _, job := range s.jobs { for _, job := range s.jobs {
baseJob := jobs.BaseJob{} baseJob := jobs.BaseJob{}
err := jsonConvert(job, &baseJob) err := jsonConvert(job, &baseJob)
if err != nil { if err != nil {
return internalError(fmt.Errorf("failed to convert job to base job: %w", err)) return Response{
StatusCode: 400,
Body: fmt.Sprintf("failed to convert job to base job: %s", err),
}
} }
list = append(list, baseJob) list = append(list, baseJob)
@ -141,9 +142,11 @@ func (s *FakeWorkspace) JobsList() (any, int) {
return list[i].JobId < list[j].JobId return list[i].JobId < list[j].JobId
}) })
return jobs.ListJobsResponse{ return Response{
Jobs: list, Body: jobs.ListJobsResponse{
}, http.StatusOK Jobs: list,
},
}
} }
// jsonConvert saves input to a value pointed by output // jsonConvert saves input to a value pointed by output
@ -163,7 +166,3 @@ func jsonConvert(input, output any) error {
return nil return nil
} }
func internalError(err error) (string, int) {
return fmt.Sprintf("internal error: %s", err), http.StatusInternalServerError
}

View File

@ -5,14 +5,14 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"reflect"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/internal/testutil"
"github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/apierr"
) )
@ -29,10 +29,10 @@ type Server struct {
RecordRequests bool RecordRequests bool
IncludeRequestHeaders []string IncludeRequestHeaders []string
Requests []Request Requests []LoggedRequest
} }
type Request struct { type LoggedRequest struct {
Headers http.Header `json:"headers,omitempty"` Headers http.Header `json:"headers,omitempty"`
Method string `json:"method"` Method string `json:"method"`
Path string `json:"path"` Path string `json:"path"`
@ -40,6 +40,144 @@ type Request struct {
RawBody string `json:"raw_body,omitempty"` RawBody string `json:"raw_body,omitempty"`
} }
type Request struct {
Method string
URL *url.URL
Headers http.Header
Body []byte
Vars map[string]string
Workspace *FakeWorkspace
}
type Response struct {
StatusCode int
Headers http.Header
Body any
}
type encodedResponse struct {
StatusCode int
Headers http.Header
Body []byte
}
func NewRequest(t testutil.TestingT, r *http.Request, fakeWorkspace *FakeWorkspace) Request {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %s", err)
}
return Request{
Method: r.Method,
URL: r.URL,
Headers: r.Header,
Body: body,
Vars: mux.Vars(r),
Workspace: fakeWorkspace,
}
}
func normalizeResponse(t testutil.TestingT, resp any) encodedResponse {
result := normalizeResponseBody(t, resp)
if result.StatusCode == 0 {
result.StatusCode = 200
}
return result
}
func normalizeResponseBody(t testutil.TestingT, resp any) encodedResponse {
if isNil(resp) {
return encodedResponse{StatusCode: 404, Body: []byte{}}
}
respBytes, ok := resp.([]byte)
if ok {
return encodedResponse{
Body: respBytes,
Headers: getHeaders(respBytes),
}
}
respString, ok := resp.(string)
if ok {
return encodedResponse{
Body: []byte(respString),
Headers: getHeaders([]byte(respString)),
}
}
respStruct, ok := resp.(Response)
if ok {
bytesVal, isBytes := respStruct.Body.([]byte)
if isBytes {
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: respStruct.Headers,
Body: bytesVal,
}
}
stringVal, isString := respStruct.Body.(string)
if isString {
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: respStruct.Headers,
Body: []byte(stringVal),
}
}
respBytes, err := json.MarshalIndent(respStruct.Body, "", " ")
if err != nil {
t.Errorf("JSON encoding error: %s", err)
return encodedResponse{
StatusCode: 500,
Body: []byte("internal error"),
}
}
headers := respStruct.Headers
if headers == nil {
headers = getJsonHeaders()
}
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: headers,
Body: respBytes,
}
}
respBytes, err := json.MarshalIndent(resp, "", " ")
if err != nil {
t.Errorf("JSON encoding error: %s", err)
return encodedResponse{
StatusCode: 500,
Body: []byte("internal error"),
}
}
return encodedResponse{
Body: respBytes,
Headers: getJsonHeaders(),
}
}
func getJsonHeaders() http.Header {
return map[string][]string{
"Content-Type": {"application/json"},
}
}
func getHeaders(value []byte) http.Header {
if json.Valid(value) {
return getJsonHeaders()
} else {
return map[string][]string{
"Content-Type": {"text/plain"},
}
}
}
func New(t testutil.TestingT) *Server { func New(t testutil.TestingT) *Server {
router := mux.NewRouter() router := mux.NewRouter()
server := httptest.NewServer(router) server := httptest.NewServer(router)
@ -96,7 +234,7 @@ Response.StatusCode = <response status-code here>
return s return s
} }
type HandlerFunc func(fakeWorkspace *FakeWorkspace, req *http.Request) (resp any, statusCode int) type HandlerFunc func(req Request) any
func (s *Server) Handle(method, path string, handler HandlerFunc) { func (s *Server) Handle(method, path string, handler HandlerFunc) {
s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
@ -117,56 +255,22 @@ func (s *Server) Handle(method, path string, handler HandlerFunc) {
fakeWorkspace = s.fakeWorkspaces[token] fakeWorkspace = s.fakeWorkspaces[token]
} }
resp, statusCode := handler(fakeWorkspace, r) request := NewRequest(s.t, r, fakeWorkspace)
if s.RecordRequests { if s.RecordRequests {
body, err := io.ReadAll(r.Body) s.Requests = append(s.Requests, getLoggedRequest(request, s.IncludeRequestHeaders))
assert.NoError(s.t, err)
headers := make(http.Header)
for k, v := range r.Header {
if !slices.Contains(s.IncludeRequestHeaders, k) {
continue
}
for _, vv := range v {
headers.Add(k, vv)
}
}
req := Request{
Headers: headers,
Method: r.Method,
Path: r.URL.Path,
}
if json.Valid(body) {
req.Body = json.RawMessage(body)
} else {
req.RawBody = string(body)
}
s.Requests = append(s.Requests, req)
} }
w.Header().Set("Content-Type", "application/json") respAny := handler(request)
w.WriteHeader(statusCode) resp := normalizeResponse(s.t, respAny)
var respBytes []byte for k, v := range resp.Headers {
var err error w.Header()[k] = v
if respString, ok := resp.(string); ok {
respBytes = []byte(respString)
} else if respBytes0, ok := resp.([]byte); ok {
respBytes = respBytes0
} else {
respBytes, err = json.MarshalIndent(resp, "", " ")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} }
if _, err := w.Write(respBytes); err != nil { w.WriteHeader(resp.StatusCode)
http.Error(w, err.Error(), http.StatusInternalServerError)
if _, err := w.Write(resp.Body); err != nil {
s.t.Errorf("Failed to write response: %s", err)
return return
} }
}).Methods(method) }).Methods(method)
@ -182,3 +286,43 @@ func getToken(r *http.Request) string {
return header[len(prefix):] return header[len(prefix):]
} }
func getLoggedRequest(req Request, includedHeaders []string) LoggedRequest {
result := LoggedRequest{
Method: req.Method,
Path: req.URL.Path,
Headers: filterHeaders(req.Headers, includedHeaders),
}
if json.Valid(req.Body) {
result.Body = json.RawMessage(req.Body)
} else {
result.RawBody = string(req.Body)
}
return result
}
func filterHeaders(h http.Header, includedHeaders []string) http.Header {
headers := make(http.Header)
for k, v := range h {
if !slices.Contains(includedHeaders, k) {
continue
}
headers[k] = v
}
return headers
}
func isNil(i any) bool {
if i == nil {
return true
}
v := reflect.ValueOf(i)
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice:
return v.IsNil()
default:
return false
}
}