Add acceptance tests for auth resolution

This commit is contained in:
Shreyas Goenka 2025-02-03 19:24:41 +01:00
parent 2eb9abb5ee
commit a7e785d0e8
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
19 changed files with 151 additions and 9 deletions

View File

@ -35,7 +35,7 @@ var (
// In order to debug CLI running under acceptance test, set this to full subtest name, e.g. "bundle/variables/empty"
// Then install your breakpoints and click "debug test" near TestAccept in VSCODE.
// example: var SingleTest = "bundle/variables/empty"
var SingleTest = ""
var SingleTest = "auth/oauth"
// If enabled, instead of compiling and running CLI externally, we'll start in-process server that accepts and runs
// CLI commands. The $CLI in test scripts is a helper that just forwards command-line arguments to this server (see bin/callserver.py).
@ -120,6 +120,7 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
if cloudEnv == "" {
defaultServer := testserver.New(t)
defaultServer.HandleUnknown()
AddHandlers(defaultServer)
// Redirect API access to local server:
t.Setenv("DATABRICKS_HOST", defaultServer.URL)
@ -156,6 +157,8 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient)
testdiff.PrepareReplacementsUUID(t, &repls)
testdiff.PrepareReplacementsDevVersion(t, &repls)
testdiff.PrepareReplacementSdkVersion(t, &repls)
testdiff.PrepareReplacementsGoVersion(t, &repls)
testDirs := getTests(t)
require.NotEmpty(t, testDirs)
@ -252,7 +255,9 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
// server otherwise is a shared resource.
if len(config.Server) > 0 || config.RecordRequests {
server = testserver.New(t)
server.HandleUnknown()
server.RecordRequests = config.RecordRequests
server.IncludeReqHeaders = config.IncludeReqHeaders
// If no custom server stubs are defined, add the default handlers.
if len(config.Server) == 0 {
@ -294,8 +299,12 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
for _, req := range server.Requests {
reqJson, err := json.Marshal(req)
if err == nil {
}
require.NoError(t, err)
// if
line := fmt.Sprintf("%s\n", reqJson)
_, err = f.WriteString(line)
require.NoError(t, err)

View File

@ -0,0 +1 @@
{"method":"GET","path":"/api/2.0/preview/scim/v2/Me","headers":{"Authorization":"Basic dXNlcm5hbWU6cGFzc3dvcmQ=","User-Agent":"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/darwin cmd/current-user_me cmd-exec-id/[UUID] auth/basic"},"body":null}

View File

@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}

View File

@ -0,0 +1,8 @@
# Unset the token which is configured by default
# in acceptance tests
export DATABRICKS_TOKEN=""
export DATABRICKS_USERNAME=username
export DATABRICKS_PASSWORD=password
$CLI current-user me

View File

@ -0,0 +1,2 @@
RecordRequests = true
IncludeReqHeaders = ["Authorization", "User-Agent"]

View File

@ -0,0 +1,3 @@
{"method":"GET","path":"/oidc/.well-known/oauth-authorization-server","headers":{"User-Agent":"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/darwin"},"body":""}
{"method":"POST","path":"/oidc/v1/token","headers":{"Authorization":"Basic Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=","User-Agent":"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/darwin"},"body":"grant_type=client_credentials\u0026scope=all-apis"}
{"method":"GET","path":"/api/2.0/preview/scim/v2/Me","headers":{"Authorization":"Bearer oauth-token","User-Agent":"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/darwin cmd/current-user_me cmd-exec-id/[UUID] auth/oauth-m2m"},"body":""}

View File

@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}

View File

@ -0,0 +1,8 @@
# Unset the token which is configured by default
# in acceptance tests
export DATABRICKS_TOKEN=""
export DATABRICKS_CLIENT_ID=client_id
export DATABRICKS_CLIENT_SECRET=client_secret
$CLI current-user me

View File

@ -0,0 +1,2 @@
RecordRequests = true
IncludeReqHeaders = ["Authorization", "User-Agent"]

View File

@ -0,0 +1 @@
{"method":"GET","path":"/api/2.0/preview/scim/v2/Me","headers":{"Authorization":"Bearer dapi1234","User-Agent":"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/darwin cmd/current-user_me cmd-exec-id/[UUID] auth/pat"},"body":null}

View File

@ -0,0 +1,4 @@
{
"id":"[USERID]",
"userName":"[USERNAME]"
}

View File

@ -0,0 +1,3 @@
export DATABRICKS_TOKEN=dapi1234
$CLI current-user me

View File

@ -0,0 +1,2 @@
RecordRequests = true
IncludeReqHeaders = ["Authorization", "User-Agent"]

View File

@ -47,6 +47,8 @@ type TestConfig struct {
// Record the requests made to the server and write them as output to
// out.requests.txt
RecordRequests bool
// Include the following request headers in the recorded requests
IncludeReqHeaders []string
}
type ServerStub struct {

View File

@ -94,4 +94,20 @@ func AddHandlers(server *testserver.Server) {
server.Handle("POST /api/2.0/workspace/mkdirs", func(r *http.Request) (any, error) {
return "{}", nil
})
server.Handle("GET /oidc/.well-known/oauth-authorization-server", func(r *http.Request) (any, error) {
return map[string]string{
"authorization_endpoint": server.URL + "oidc/v1/authorize",
"token_endpoint": server.URL + "/oidc/v1/token",
}, nil
})
server.Handle("POST /oidc/v1/token", func(r *http.Request) (any, error) {
return map[string]string{
"access_token": "oauth-token",
"expires_in": "3600",
"scope": "all-apis",
"token_type": "Bearer",
}, nil
})
}

View File

@ -14,7 +14,7 @@ import (
var OverwriteMode = false
func init() {
flag.BoolVar(&OverwriteMode, "update", false, "Overwrite golden files")
flag.BoolVar(&OverwriteMode, "update", true, "Overwrite golden files")
}
func ReadFile(t testutil.TestingT, ctx context.Context, filename string) string {

View File

@ -12,6 +12,7 @@ import (
"github.com/databricks/cli/libs/iamutil"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/iam"
"golang.org/x/mod/semver"
)
const (
@ -208,3 +209,25 @@ func PrepareReplacementsDevVersion(t testutil.TestingT, r *ReplacementsContext)
t.Helper()
r.append(devVersionRegex, "[DEV_VERSION]")
}
func PrepareReplacementSdkVersion(t testutil.TestingT, r *ReplacementsContext) {
t.Helper()
r.Set(databricks.Version(), "[SDK_VERSION]")
}
func goVersion() string {
gv := runtime.Version()
ssv := strings.ReplaceAll(gv, "go", "v")
sv := semver.Canonical(ssv)
return strings.TrimPrefix(sv, "v")
}
func PrepareReplacementsGoVersion(t testutil.TestingT, r *ReplacementsContext) {
t.Helper()
r.Set(goVersion(), "[GO_VERSION]")
}
func PrepareReplaceOS(t testutil.TestingT, r *ReplacementsContext) {
t.Helper()
r.Set(runtime.GOOS, "[OS]")
}

View File

@ -1,6 +1,7 @@
package testdiff
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
@ -44,3 +45,11 @@ func TestReplacement_TemporaryDirectory(t *testing.T) {
assert.Equal(t, "/tmp/.../tail", repls.Replace("/tmp/foo/bar/qux/tail"))
}
func TestReplacement_OS(t *testing.T) {
var repls ReplacementsContext
PrepareReplaceOS(t, &repls)
assert.Equal(t, "[OS]", repls.Replace(runtime.GOOS))
}

View File

@ -2,9 +2,12 @@ package testserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"slices"
"github.com/stretchr/testify/assert"
@ -17,15 +20,17 @@ type Server struct {
t testutil.TestingT
RecordRequests bool
RecordRequests bool
IncludeReqHeaders []string
Requests []Request
}
type Request struct {
Method string `json:"method"`
Path string `json:"path"`
Body any `json:"body"`
Method string `json:"method"`
Path string `json:"path"`
Headers map[string]string `json:"headers,omitempty"`
Body any `json:"body,omitempty"`
}
func New(t testutil.TestingT) *Server {
@ -40,6 +45,23 @@ func New(t testutil.TestingT) *Server {
}
}
func (s *Server) HandleUnknown() {
s.Handle("/", func(req *http.Request) (any, error) {
msg := fmt.Sprintf(`
unknown API request received. Please add a handler for this request in
your test. You can copy the following snippet in your test.toml file:
[[Server]]
Pattern = %s %s
Response = '''
<response here>
'''`, req.Method, req.URL.Path)
s.t.Fatalf(msg)
return nil, errors.New("unknown API request")
})
}
type HandlerFunc func(req *http.Request) (resp any, err error)
func (s *Server) Handle(pattern string, handler HandlerFunc) {
@ -54,10 +76,29 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) {
body, err := io.ReadAll(r.Body)
assert.NoError(s.t, err)
headers := make(map[string]string)
for k, v := range r.Header {
if !slices.Contains(s.IncludeReqHeaders, k) {
continue
}
if len(v) == 0 {
continue
}
headers[k] = v[0]
}
var reqBody any
if len(body) > 0 && body[0] == '{' {
reqBody = json.RawMessage(body)
} else {
reqBody = string(body)
}
s.Requests = append(s.Requests, Request{
Method: r.Method,
Path: r.URL.Path,
Body: json.RawMessage(body),
Method: r.Method,
Path: r.URL.Path,
Headers: headers,
Body: reqBody,
})
}