databricks-cli/libs/testserver/server.go

338 lines
7.1 KiB
Go

package testserver
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"slices"
"strings"
"sync"
"github.com/gorilla/mux"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/databricks-sdk-go/apierr"
)
// Server is a fake HTTP backend for tests: an httptest.Server whose requests
// are dispatched by a gorilla/mux router to handlers registered with Handle.
type Server struct {
	// Embedded so callers can use the httptest.Server (e.g. its URL) directly.
	*httptest.Server
	// Router dispatches incoming requests; Handle registers routes on it.
	Router *mux.Router
	// t receives failures reported while serving requests.
	t testutil.TestingT
	// fakeWorkspaces holds one FakeWorkspace per bearer token seen by Handle.
	fakeWorkspaces map[string]*FakeWorkspace
	// mu is held by Handle for the whole of each request, so handled requests
	// are served one at a time.
	mu *sync.Mutex
	// RecordRequests, when true, makes Handle append every request it serves
	// to Requests.
	RecordRequests bool
	// IncludeRequestHeaders names the request headers kept in recorded
	// requests; other headers are dropped from the log.
	IncludeRequestHeaders []string
	// Requests is the request log, filled only when RecordRequests is set.
	Requests []LoggedRequest
}
// LoggedRequest is the JSON-serializable record of a request stored in
// Server.Requests.
type LoggedRequest struct {
	// Headers holds only the headers selected by Server.IncludeRequestHeaders.
	Headers http.Header `json:"headers,omitempty"`
	Method  string      `json:"method"`
	// Path is the request URL path; the query string is not recorded.
	Path string `json:"path"`
	// Body is set (as raw JSON) when the request body is valid JSON.
	Body any `json:"body,omitempty"`
	// RawBody is set to the body text when the request body is not valid JSON.
	RawBody string `json:"raw_body,omitempty"`
}
// Request is the handler-facing view of an incoming HTTP request, built by
// NewRequest and passed to a HandlerFunc.
type Request struct {
	Method  string
	URL     *url.URL
	Headers http.Header
	// Body is the complete request body, read eagerly by NewRequest.
	Body []byte
	// Vars holds the path variables captured by the matching mux route.
	Vars map[string]string
	// Workspace is the fake workspace tied to the request's bearer token.
	// When the request is served through Server.Handle it is nil for
	// requests that carry no bearer token.
	Workspace *FakeWorkspace
}
// Response lets a handler control the status code and headers of its reply.
type Response struct {
	// StatusCode is replaced with 200 when left at 0.
	StatusCode int
	// Headers are sent as given. When nil and Body is JSON-encoded, a
	// Content-Type: application/json header is used instead.
	Headers http.Header
	// Body may be nil (empty reply), a []byte or string (sent verbatim), or
	// any other value, which is JSON-encoded.
	Body any
}
// encodedResponse is a handler result reduced to the status code, headers, and
// body bytes that are written back to the client.
type encodedResponse struct {
	// StatusCode of 0 means "not set"; normalizeResponse replaces it with 200.
	StatusCode int
	Headers    http.Header
	Body       []byte
}
// NewRequest builds a Request from an incoming HTTP request. It reads the
// whole body into memory and captures the mux route variables. URL and
// Headers alias r.URL and r.Header; they are not copied. fakeWorkspace is
// stored as-is and may be nil.
//
// A body read failure is reported with t.Errorf rather than t.Fatalf. When
// NewRequest is called from an HTTP handler (as Server.Handle does), it runs
// on the server's goroutine, and FailNow (which Fatalf calls) may only be
// called from the goroutine running the test. In that case the bytes read
// before the error are still returned in Body.
func NewRequest(t testutil.TestingT, r *http.Request, fakeWorkspace *FakeWorkspace) Request {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		t.Errorf("Failed to read request body: %s", err)
	}

	return Request{
		Method:    r.Method,
		URL:       r.URL,
		Headers:   r.Header,
		Body:      body,
		Vars:      mux.Vars(r),
		Workspace: fakeWorkspace,
	}
}
// normalizeResponse encodes a handler's return value for the wire and fills in
// http.StatusOK when the handler did not choose a status code.
func normalizeResponse(t testutil.TestingT, resp any) encodedResponse {
	encoded := normalizeResponseBody(t, resp)
	if encoded.StatusCode != 0 {
		return encoded
	}
	encoded.StatusCode = http.StatusOK
	return encoded
}
// normalizeResponseBody turns a handler's return value into the bytes and
// headers to send. StatusCode stays 0 unless a Response supplied one or an
// error occurred; normalizeResponse applies the 200 default afterwards.
//
// How each kind of value is encoded:
//   - nil (including typed nils such as a nil slice or pointer): reported via
//     t.Errorf; the result is a bare 500.
//   - []byte or string: sent verbatim, with Content-Type application/json when
//     the bytes are valid JSON and text/plain otherwise.
//   - Response: StatusCode and Headers are copied from it. A nil Body gives an
//     empty body. A []byte or string Body is sent verbatim with Headers as
//     given (possibly nil). Any other Body is JSON-encoded, and Headers
//     defaults to application/json when nil.
//   - anything else: JSON-encoded with an application/json Content-Type.
//
// A JSON encoding failure is reported via t.Errorf and yields a 500 whose body
// is "internal error" and whose headers are nil.
func normalizeResponseBody(t testutil.TestingT, resp any) encodedResponse {
	// toJSON encodes v, reporting any failure to t; ok is false on failure.
	toJSON := func(v any) ([]byte, bool) {
		encoded, err := json.MarshalIndent(v, "", " ")
		if err != nil {
			t.Errorf("JSON encoding error: %s", err)
			return nil, false
		}
		return encoded, true
	}
	// encodingFailure is the reply used whenever toJSON fails.
	encodingFailure := func() encodedResponse {
		return encodedResponse{StatusCode: 500, Body: []byte("internal error")}
	}

	if isNil(resp) {
		t.Errorf("Handler must not return nil")
		return encodedResponse{StatusCode: 500}
	}

	switch value := resp.(type) {
	case []byte:
		return encodedResponse{Body: value, Headers: getHeaders(value)}

	case string:
		raw := []byte(value)
		return encodedResponse{Body: raw, Headers: getHeaders(raw)}

	case Response:
		if isNil(value.Body) {
			return encodedResponse{StatusCode: value.StatusCode, Headers: value.Headers, Body: []byte{}}
		}
		switch body := value.Body.(type) {
		case []byte:
			return encodedResponse{StatusCode: value.StatusCode, Headers: value.Headers, Body: body}
		case string:
			return encodedResponse{StatusCode: value.StatusCode, Headers: value.Headers, Body: []byte(body)}
		}

		encoded, ok := toJSON(value.Body)
		if !ok {
			return encodingFailure()
		}
		headers := value.Headers
		if headers == nil {
			headers = getJsonHeaders()
		}
		return encodedResponse{StatusCode: value.StatusCode, Headers: headers, Body: encoded}

	default:
		encoded, ok := toJSON(value)
		if !ok {
			return encodingFailure()
		}
		return encodedResponse{Body: encoded, Headers: getJsonHeaders()}
	}
}
// getJsonHeaders returns a freshly allocated header set declaring a JSON
// payload. A new map is built on every call, so callers may modify it freely.
func getJsonHeaders() http.Header {
	headers := make(http.Header, 1)
	headers["Content-Type"] = []string{"application/json"}
	return headers
}
// getHeaders chooses a Content-Type for a raw response body. It returns
// application/json when value is syntactically valid JSON, and text/plain
// otherwise (including for an empty body, which is not valid JSON).
func getHeaders(value []byte) http.Header {
	if json.Valid(value) {
		return getJsonHeaders()
	}
	return http.Header{
		"Content-Type": {"text/plain"},
	}
}
// New starts a test HTTP server backed by a gorilla/mux router and registers
// t.Cleanup to close it when the test ends. Routes are added with Handle.
//
// Some requests reach no registered handler: either no route matches the
// path, or a route matches the path but not the HTTP method. Such a request
// fails the test via t.Errorf, with instructions for stubbing it, and
// receives a 501 Not Implemented JSON error body.
func New(t testutil.TestingT) *Server {
	router := mux.NewRouter()
	server := httptest.NewServer(router)
	t.Cleanup(server.Close)

	s := &Server{
		Server:         server,
		Router:         router,
		t:              t,
		mu:             &sync.Mutex{},
		fakeWorkspaces: map[string]*FakeWorkspace{},
	}

	// noStub reports the unmatched request and answers it with an API-style error.
	noStub := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pattern := r.Method + " " + r.URL.Path

		t.Errorf(`
----------------------------------------
No stub found for pattern: %s
To stub a response for this request, you can add
the following to test.toml:
[[Server]]
Pattern = %q
Response.Body = '''
<response body here>
'''
Response.StatusCode = <response status-code here>
----------------------------------------
`, pattern, pattern)

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusNotImplemented)

		resp := apierr.APIError{
			Message: "No stub found for pattern: " + pattern,
		}

		respBytes, err := json.Marshal(resp)
		if err != nil {
			t.Errorf("JSON encoding error: %s", err)
			respBytes = []byte("{\"message\": \"JSON encoding error\"}")
		}

		if _, err := w.Write(respBytes); err != nil {
			t.Errorf("Response write error: %s", err)
		}
	})

	// Fallback for requests whose path matches no route.
	router.NotFoundHandler = noStub
	// Without this, gorilla/mux answers a known path requested with an
	// unregistered method with a bare 405, and the test never learns that a
	// stub is missing.
	router.MethodNotAllowedHandler = noStub

	return s
}
// HandlerFunc serves one stubbed request. Its return value is passed to
// normalizeResponse. It may be a []byte, a string, a Response, or any other
// value, which is JSON-encoded. Returning nil (or a typed nil) fails the test
// and produces a 500.
type HandlerFunc func(req Request) any
// Handle registers handler on s.Router for path, restricted to method.
//
// Requests are served one at a time while s.mu is held. Each distinct bearer
// token gets its own FakeWorkspace, created on the first request that carries
// it; requests without a bearer token get a nil Workspace. When
// RecordRequests is set, the request is appended to s.Requests before the
// handler runs. The handler's return value is encoded by normalizeResponse
// and written back as the response.
func (s *Server) Handle(method, path string, handler HandlerFunc) {
	s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
		// For simplicity we process requests sequentially. It's fast enough because
		// we don't do any IO except reading and writing request/response bodies.
		s.mu.Lock()
		defer s.mu.Unlock()

		// Each test uses a unique DATABRICKS_TOKEN; we simulate each token having
		// its own fake workspace to avoid interference between tests.
		var fakeWorkspace *FakeWorkspace
		if token := getToken(r); token != "" {
			workspace, ok := s.fakeWorkspaces[token]
			if !ok {
				workspace = NewFakeWorkspace()
				s.fakeWorkspaces[token] = workspace
			}
			fakeWorkspace = workspace
		}

		request := NewRequest(s.t, r, fakeWorkspace)
		if s.RecordRequests {
			s.Requests = append(s.Requests, getLoggedRequest(request, s.IncludeRequestHeaders))
		}

		resp := normalizeResponse(s.t, handler(request))
		for k, v := range resp.Headers {
			w.Header()[k] = v
		}
		w.WriteHeader(resp.StatusCode)
		if _, err := w.Write(resp.Body); err != nil {
			s.t.Errorf("Failed to write response: %s", err)
		}
	}).Methods(method)
}
// getToken extracts the bearer token from the request's Authorization header.
// It returns "" when the header is missing or does not start with the exact,
// case-sensitive prefix "Bearer ".
func getToken(r *http.Request) string {
	token, found := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
	if !found {
		return ""
	}
	return token
}
// getLoggedRequest captures req for the request log. Headers are reduced to
// includedHeaders via filterHeaders. A body that is valid JSON is kept as raw
// JSON in Body; any other body, including an empty one, is stored as text in
// RawBody.
func getLoggedRequest(req Request, includedHeaders []string) LoggedRequest {
	logged := LoggedRequest{
		Method:  req.Method,
		Path:    req.URL.Path,
		Headers: filterHeaders(req.Headers, includedHeaders),
	}
	if !json.Valid(req.Body) {
		logged.RawBody = string(req.Body)
		return logged
	}
	logged.Body = json.RawMessage(req.Body)
	return logged
}
// filterHeaders returns a new header map holding only the entries of h whose
// names appear in includedHeaders. Names are compared case-insensitively, as
// HTTP header names are, so "user-agent" selects the canonical "User-Agent"
// key that net/http stores for incoming requests. Kept entries retain the key
// spelling used in h and share their value slices with h.
//
// The result is never nil; it is empty when nothing matches.
func filterHeaders(h http.Header, includedHeaders []string) http.Header {
	headers := make(http.Header)
	for k, v := range h {
		wanted := slices.ContainsFunc(includedHeaders, func(name string) bool {
			return strings.EqualFold(name, k)
		})
		if wanted {
			headers[k] = v
		}
	}
	return headers
}
// isNil reports whether i is nil, either as an untyped nil interface or as a
// nil pointer, map, slice, channel, function, or interface wrapped in i.
// Every other kind, including zero-valued structs and numbers, is non-nil.
func isNil(i any) bool {
	if i == nil {
		return true
	}
	switch rv := reflect.ValueOf(i); rv.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func, reflect.Interface:
		return rv.IsNil()
	}
	return false
}