diff --git a/acceptance/acceptance_test.go b/acceptance/acceptance_test.go index c7b1151ab..c0fa960b6 100644 --- a/acceptance/acceptance_test.go +++ b/acceptance/acceptance_test.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "io" + "net/http" "os" "os/exec" "path/filepath" @@ -27,6 +28,7 @@ import ( "github.com/databricks/cli/libs/testserver" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -263,8 +265,23 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont if len(config.Server) > 0 || config.RecordRequests { server = testserver.New(t) - server.RecordRequests = config.RecordRequests - server.IncludeRequestHeaders = config.IncludeRequestHeaders + if config.RecordRequests { + requestsPath := filepath.Join(tmpDir, "out.requests.txt") + server.RecordRequestsCallback = func(request *testserver.Request) { + req := getLoggedRequest(request, config.IncludeRequestHeaders) + reqJson, err := json.MarshalIndent(req, "", " ") + assert.NoErrorf(t, err, "Failed to indent: %#v", req) + + reqJsonWithRepls := repls.Replace(string(reqJson)) + + f, err := os.OpenFile(requestsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + assert.NoError(t, err) + defer f.Close() + + _, err = f.WriteString(reqJsonWithRepls + "\n") + assert.NoError(t, err) + } + } // We want later stubs takes precedence, because then leaf configs take precedence over parent directory configs // In gorilla/mux earlier handlers take precedence, so we need to reverse the order @@ -345,25 +362,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont cmd.Dir = tmpDir err = cmd.Run() - // Write the requests made to the server to a output file if the test is - // configured to record requests. - if config.RecordRequests { - f, err := os.OpenFile(filepath.Join(tmpDir, "out.requests.txt"), os.O_CREATE|os.O_WRONLY, 0o644) - require.NoError(t, err) - - for _, req := range server.Requests { - reqJson, err := json.MarshalIndent(req, "", " ") - require.NoErrorf(t, err, "Failed to indent: %#v", req) - - reqJsonWithRepls := repls.Replace(string(reqJson)) - _, err = f.WriteString(reqJsonWithRepls + "\n") - require.NoError(t, err) - } - - err = f.Close() - require.NoError(t, err) - } - // Include exit code in output (if non-zero) formatOutput(out, err) require.NoError(t, out.Close()) @@ -670,3 +668,38 @@ func RunCommand(t *testing.T, args []string, dir string) { t.Logf("%s output: %s", args, out) } } + +type LoggedRequest struct { + Headers http.Header `json:"headers,omitempty"` + Method string `json:"method"` + Path string `json:"path"` + Body any `json:"body,omitempty"` + RawBody string `json:"raw_body,omitempty"` +} + +func getLoggedRequest(req *testserver.Request, includedHeaders []string) LoggedRequest { + result := LoggedRequest{ + Method: req.Method, + Path: req.URL.Path, + Headers: filterHeaders(req.Headers, includedHeaders), + } + + if json.Valid(req.Body) { + result.Body = json.RawMessage(req.Body) + } else { + result.RawBody = string(req.Body) + } + + return result +} + +func filterHeaders(h http.Header, includedHeaders []string) http.Header { + headers := make(http.Header) + for k, v := range h { + if !slices.Contains(includedHeaders, k) { + continue + } + headers[k] = v + } + return headers +} diff --git a/libs/testserver/server.go b/libs/testserver/server.go index fa15973d7..a10ddf4d8 100644 --- a/libs/testserver/server.go +++ b/libs/testserver/server.go @@ -7,7 +7,6 @@ import ( "net/http/httptest" "net/url" "reflect" - "slices" "strings" "sync" @@ -26,18 +25,7 @@ type Server struct { fakeWorkspaces map[string]*FakeWorkspace mu *sync.Mutex - RecordRequests bool - IncludeRequestHeaders []string - - Requests []LoggedRequest -} - -type LoggedRequest struct { - Headers http.Header `json:"headers,omitempty"` - Method string `json:"method"` - Path string `json:"path"` - Body any `json:"body,omitempty"` - RawBody string `json:"raw_body,omitempty"` + RecordRequestsCallback func(request *Request) } type Request struct { @@ -265,10 +253,9 @@ func (s *Server) Handle(method, path string, handler HandlerFunc) { } request := NewRequest(s.t, r, fakeWorkspace) - if s.RecordRequests { - s.Requests = append(s.Requests, getLoggedRequest(request, s.IncludeRequestHeaders)) + if s.RecordRequestsCallback != nil { + s.RecordRequestsCallback(&request) } - respAny := handler(request) resp := normalizeResponse(s.t, respAny) @@ -296,33 +283,6 @@ func getToken(r *http.Request) string { return header[len(prefix):] } -func getLoggedRequest(req Request, includedHeaders []string) LoggedRequest { - result := LoggedRequest{ - Method: req.Method, - Path: req.URL.Path, - Headers: filterHeaders(req.Headers, includedHeaders), - } - - if json.Valid(req.Body) { - result.Body = json.RawMessage(req.Body) - } else { - result.RawBody = string(req.Body) - } - - return result -} - -func filterHeaders(h http.Header, includedHeaders []string) http.Header { - headers := make(http.Header) - for k, v := range h { - if !slices.Contains(includedHeaders, k) { - continue - } - headers[k] = v - } - return headers -} - func isNil(i any) bool { if i == nil { return true