mirror of https://github.com/databricks/cli.git
Compare commits
30 Commits
3d09416e98
...
73a3399beb
Author | SHA1 | Date |
---|---|---|
|
73a3399beb | |
|
24ac8d8d59 | |
|
5d392acbef | |
|
272ce61302 | |
|
878fa80322 | |
|
8d849fe868 | |
|
ca08796f77 | |
|
fc23aa584d | |
|
6af6b55832 | |
|
865964e029 | |
|
41999fbe87 | |
|
d2bead3fe6 | |
|
11c37673a6 | |
|
18d3fea34e | |
|
b7ff019b60 | |
|
bb35ca090f | |
|
d037ec32a1 | |
|
89d3b1a4df | |
|
37067ef933 | |
|
171c3fdd75 | |
|
dc44dbd667 | |
|
b044a6c0e0 | |
|
7636c55ba9 | |
|
e88fd0a5c0 | |
|
6c32a0df7a | |
|
7eca34a7b2 | |
|
6277cf24c6 | |
|
6a8b2f452f | |
|
712e2919f5 | |
|
882ccba0f5 |
4
NOTICE
4
NOTICE
|
@ -114,3 +114,7 @@ dario.cat/mergo
|
|||
Copyright (c) 2013 Dario Castañé. All rights reserved.
|
||||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
https://github.com/darccio/mergo/blob/master/LICENSE
|
||||
|
||||
https://github.com/gorilla/mux
|
||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||
https://github.com/gorilla/mux/blob/main/LICENSE
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
|
@ -26,6 +27,7 @@ import (
|
|||
"github.com/databricks/cli/libs/testdiff"
|
||||
"github.com/databricks/cli/libs/testserver"
|
||||
"github.com/databricks/databricks-sdk-go"
|
||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -72,7 +74,8 @@ func TestInprocessMode(t *testing.T) {
|
|||
if InprocessMode {
|
||||
t.Skip("Already tested by TestAccept")
|
||||
}
|
||||
require.Equal(t, 1, testAccept(t, true, "selftest"))
|
||||
require.Equal(t, 1, testAccept(t, true, "selftest/basic"))
|
||||
require.Equal(t, 1, testAccept(t, true, "selftest/server"))
|
||||
}
|
||||
|
||||
func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
|
||||
|
@ -118,14 +121,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
|
|||
uvCache := getUVDefaultCacheDir(t)
|
||||
t.Setenv("UV_CACHE_DIR", uvCache)
|
||||
|
||||
ctx := context.Background()
|
||||
cloudEnv := os.Getenv("CLOUD_ENV")
|
||||
|
||||
if cloudEnv == "" {
|
||||
defaultServer := testserver.New(t)
|
||||
AddHandlers(defaultServer)
|
||||
// Redirect API access to local server:
|
||||
t.Setenv("DATABRICKS_HOST", defaultServer.URL)
|
||||
t.Setenv("DATABRICKS_DEFAULT_HOST", defaultServer.URL)
|
||||
|
||||
homeDir := t.TempDir()
|
||||
// Do not read user's ~/.databrickscfg
|
||||
|
@ -148,27 +149,12 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
|
|||
// do it last so that full paths match first:
|
||||
repls.SetPath(buildDir, "[BUILD_DIR]")
|
||||
|
||||
var config databricks.Config
|
||||
if cloudEnv == "" {
|
||||
// use fake token for local tests
|
||||
config = databricks.Config{Token: "dbapi1234"}
|
||||
} else {
|
||||
// non-local tests rely on environment variables
|
||||
config = databricks.Config{}
|
||||
}
|
||||
workspaceClient, err := databricks.NewWorkspaceClient(&config)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := workspaceClient.CurrentUser.Me(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
testdiff.PrepareReplacementsUser(t, &repls, *user)
|
||||
testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient)
|
||||
testdiff.PrepareReplacementsUUID(t, &repls)
|
||||
testdiff.PrepareReplacementsDevVersion(t, &repls)
|
||||
testdiff.PrepareReplacementSdkVersion(t, &repls)
|
||||
testdiff.PrepareReplacementsGoVersion(t, &repls)
|
||||
|
||||
repls.Repls = append(repls.Repls, testdiff.Replacement{Old: regexp.MustCompile("dbapi[0-9a-f]+"), New: "[DATABRICKS_TOKEN]"})
|
||||
|
||||
testDirs := getTests(t)
|
||||
require.NotEmpty(t, testDirs)
|
||||
|
||||
|
@ -180,8 +166,7 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
|
|||
}
|
||||
|
||||
for _, dir := range testDirs {
|
||||
testName := strings.ReplaceAll(dir, "\\", "/")
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Run(dir, func(t *testing.T) {
|
||||
if !InprocessMode {
|
||||
t.Parallel()
|
||||
}
|
||||
|
@ -203,7 +188,8 @@ func getTests(t *testing.T) []string {
|
|||
name := filepath.Base(path)
|
||||
if name == EntryPointScript {
|
||||
// Presence of 'script' marks a test case in this directory
|
||||
testDirs = append(testDirs, filepath.Dir(path))
|
||||
testName := filepath.ToSlash(filepath.Dir(path))
|
||||
testDirs = append(testDirs, testName)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
@ -239,7 +225,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
|
|||
}
|
||||
|
||||
repls.SetPathWithParents(tmpDir, "[TMPDIR]")
|
||||
repls.Repls = append(repls.Repls, config.Repls...)
|
||||
|
||||
scriptContents := readMergedScriptContents(t, dir)
|
||||
testutil.WriteFile(t, filepath.Join(tmpDir, EntryPointScript), scriptContents)
|
||||
|
@ -253,38 +238,79 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
|
|||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
var workspaceClient *databricks.WorkspaceClient
|
||||
var user iam.User
|
||||
|
||||
// Start a new server with a custom configuration if the acceptance test
|
||||
// specifies a custom server stubs.
|
||||
var server *testserver.Server
|
||||
|
||||
// Start a new server for this test if either:
|
||||
// 1. A custom server spec is defined in the test configuration.
|
||||
// 2. The test is configured to record requests and assert on them. We need
|
||||
// a duplicate of the default server to record requests because the default
|
||||
// server otherwise is a shared resource.
|
||||
if len(config.Server) > 0 || config.RecordRequests {
|
||||
server = testserver.New(t)
|
||||
server.RecordRequests = config.RecordRequests
|
||||
server.IncludeRequestHeaders = config.IncludeRequestHeaders
|
||||
if cloudEnv == "" {
|
||||
// Start a new server for this test if either:
|
||||
// 1. A custom server spec is defined in the test configuration.
|
||||
// 2. The test is configured to record requests and assert on them. We need
|
||||
// a duplicate of the default server to record requests because the default
|
||||
// server otherwise is a shared resource.
|
||||
|
||||
// If no custom server stubs are defined, add the default handlers.
|
||||
if len(config.Server) == 0 {
|
||||
databricksLocalHost := os.Getenv("DATABRICKS_DEFAULT_HOST")
|
||||
|
||||
if len(config.Server) > 0 || config.RecordRequests {
|
||||
server = testserver.New(t)
|
||||
server.RecordRequests = config.RecordRequests
|
||||
server.IncludeRequestHeaders = config.IncludeRequestHeaders
|
||||
|
||||
for _, stub := range config.Server {
|
||||
require.NotEmpty(t, stub.Pattern)
|
||||
items := strings.Split(stub.Pattern, " ")
|
||||
require.Len(t, items, 2)
|
||||
server.Handle(items[0], items[1], func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) {
|
||||
statusCode := http.StatusOK
|
||||
if stub.Response.StatusCode != 0 {
|
||||
statusCode = stub.Response.StatusCode
|
||||
}
|
||||
return stub.Response.Body, statusCode
|
||||
})
|
||||
}
|
||||
|
||||
// The earliest handlers take precedence, add default handlers last
|
||||
AddHandlers(server)
|
||||
databricksLocalHost = server.URL
|
||||
}
|
||||
|
||||
for _, stub := range config.Server {
|
||||
require.NotEmpty(t, stub.Pattern)
|
||||
server.Handle(stub.Pattern, func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) {
|
||||
statusCode := http.StatusOK
|
||||
if stub.Response.StatusCode != 0 {
|
||||
statusCode = stub.Response.StatusCode
|
||||
}
|
||||
return stub.Response.Body, statusCode
|
||||
})
|
||||
// Each local test should use a new token that will result into a new fake workspace,
|
||||
// so that test don't interfere with each other.
|
||||
tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
config := databricks.Config{
|
||||
Host: databricksLocalHost,
|
||||
Token: "dbapi" + tokenSuffix,
|
||||
}
|
||||
cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+server.URL)
|
||||
workspaceClient, err = databricks.NewWorkspaceClient(&config)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+config.Host)
|
||||
cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+config.Token)
|
||||
|
||||
// For the purposes of replacements, use testUser.
|
||||
// Note, users might have overriden /api/2.0/preview/scim/v2/Me but that should not affect the replacement:
|
||||
user = testUser
|
||||
} else {
|
||||
// Use whatever authentication mechanism is configured by the test runner.
|
||||
workspaceClient, err = databricks.NewWorkspaceClient(&databricks.Config{})
|
||||
require.NoError(t, err)
|
||||
pUser, err := workspaceClient.CurrentUser.Me(context.Background())
|
||||
require.NoError(t, err, "Failed to get current user")
|
||||
user = *pUser
|
||||
}
|
||||
|
||||
testdiff.PrepareReplacementsUser(t, &repls, user)
|
||||
testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient)
|
||||
|
||||
// Must be added PrepareReplacementsUser, otherwise conflicts with [USERNAME]
|
||||
testdiff.PrepareReplacementsUUID(t, &repls)
|
||||
|
||||
// User replacements come last:
|
||||
repls.Repls = append(repls.Repls, config.Repls...)
|
||||
|
||||
if coverDir != "" {
|
||||
// Creating individual coverage directory for each test, because writing to the same one
|
||||
// results in sporadic failures like this one (only if tests are running in parallel):
|
||||
|
@ -295,15 +321,6 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
|
|||
cmd.Env = append(cmd.Env, "GOCOVERDIR="+coverDir)
|
||||
}
|
||||
|
||||
// Each local test should use a new token that will result into a new fake workspace,
|
||||
// so that test don't interfere with each other.
|
||||
if cloudEnv == "" {
|
||||
tokenSuffix := strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
token := "dbapi" + tokenSuffix
|
||||
cmd.Env = append(cmd.Env, "DATABRICKS_TOKEN="+token)
|
||||
repls.Set(token, "[DATABRICKS_TOKEN]")
|
||||
}
|
||||
|
||||
// Write combined output to a file
|
||||
out, err := os.Create(filepath.Join(tmpDir, "output.txt"))
|
||||
require.NoError(t, err)
|
||||
|
@ -320,7 +337,7 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
|
|||
|
||||
for _, req := range server.Requests {
|
||||
reqJson, err := json.MarshalIndent(req, "", " ")
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "Failed to indent: %#v", req)
|
||||
|
||||
reqJsonWithRepls := repls.Replace(string(reqJson))
|
||||
_, err = f.WriteString(reqJsonWithRepls + "\n")
|
||||
|
|
|
@ -13,13 +13,13 @@
|
|||
|
||||
=== Inside the bundle, profile flag not matching bundle host. Badness: should use profile from flag instead and not fail
|
||||
>>> errcode [CLI] current-user me -p profile_name
|
||||
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL]
|
||||
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET]
|
||||
|
||||
Exit code: 1
|
||||
|
||||
=== Inside the bundle, target and not matching profile
|
||||
>>> errcode [CLI] current-user me -t dev -p profile_name
|
||||
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_URL]
|
||||
Error: cannot resolve bundle auth configuration: config host mismatch: profile uses host https://non-existing-subdomain.databricks.com, but CLI configured to use [DATABRICKS_TARGET]
|
||||
|
||||
Exit code: 1
|
||||
|
||||
|
|
|
@ -5,4 +5,8 @@ Badness = "When -p flag is used inside the bundle folder for any CLI commands, C
|
|||
# This is a workaround to replace DATABRICKS_URL with DATABRICKS_HOST
|
||||
[[Repls]]
|
||||
Old='DATABRICKS_HOST'
|
||||
New='DATABRICKS_URL'
|
||||
New='DATABRICKS_TARGET'
|
||||
|
||||
[[Repls]]
|
||||
Old='DATABRICKS_URL'
|
||||
New='DATABRICKS_TARGET'
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"headers": {
|
||||
"Authorization": [
|
||||
"Basic [ENCODED_AUTH]"
|
||||
],
|
||||
"User-Agent": [
|
||||
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/basic"
|
||||
]
|
||||
},
|
||||
"method": "GET",
|
||||
"path": "/api/2.0/preview/scim/v2/Me"
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"id":"[USERID]",
|
||||
"userName":"[USERNAME]"
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
# Unset the token which is configured by default
|
||||
# in acceptance tests
|
||||
export DATABRICKS_TOKEN=""
|
||||
|
||||
export DATABRICKS_USERNAME=username
|
||||
export DATABRICKS_PASSWORD=password
|
||||
|
||||
$CLI current-user me
|
|
@ -0,0 +1,4 @@
|
|||
# "username:password" in base64 is dXNlcm5hbWU6cGFzc3dvcmQ=, expect to see this in Authorization header
|
||||
[[Repls]]
|
||||
Old = "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
New = "[ENCODED_AUTH]"
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"headers": {
|
||||
"User-Agent": [
|
||||
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS]"
|
||||
]
|
||||
},
|
||||
"method": "GET",
|
||||
"path": "/oidc/.well-known/oauth-authorization-server"
|
||||
}
|
||||
{
|
||||
"headers": {
|
||||
"Authorization": [
|
||||
"Basic [ENCODED_AUTH]"
|
||||
],
|
||||
"User-Agent": [
|
||||
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS]"
|
||||
]
|
||||
},
|
||||
"method": "POST",
|
||||
"path": "/oidc/v1/token",
|
||||
"raw_body": "grant_type=client_credentials\u0026scope=all-apis"
|
||||
}
|
||||
{
|
||||
"headers": {
|
||||
"Authorization": [
|
||||
"Bearer oauth-token"
|
||||
],
|
||||
"User-Agent": [
|
||||
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/oauth-m2m"
|
||||
]
|
||||
},
|
||||
"method": "GET",
|
||||
"path": "/api/2.0/preview/scim/v2/Me"
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"id":"[USERID]",
|
||||
"userName":"[USERNAME]"
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
# Unset the token which is configured by default
|
||||
# in acceptance tests
|
||||
export DATABRICKS_TOKEN=""
|
||||
|
||||
export DATABRICKS_CLIENT_ID=client_id
|
||||
export DATABRICKS_CLIENT_SECRET=client_secret
|
||||
|
||||
$CLI current-user me
|
|
@ -0,0 +1,5 @@
|
|||
# "client_id:client_secret" in base64 is Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ=, expect to
|
||||
# see this in Authorization header
|
||||
[[Repls]]
|
||||
Old = "Y2xpZW50X2lkOmNsaWVudF9zZWNyZXQ="
|
||||
New = "[ENCODED_AUTH]"
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"headers": {
|
||||
"Authorization": [
|
||||
"Bearer dapi1234"
|
||||
],
|
||||
"User-Agent": [
|
||||
"cli/[DEV_VERSION] databricks-sdk-go/[SDK_VERSION] go/[GO_VERSION] os/[OS] cmd/current-user_me cmd-exec-id/[UUID] auth/pat"
|
||||
]
|
||||
},
|
||||
"method": "GET",
|
||||
"path": "/api/2.0/preview/scim/v2/Me"
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"id":"[USERID]",
|
||||
"userName":"[USERNAME]"
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
export DATABRICKS_TOKEN=dapi1234
|
||||
|
||||
$CLI current-user me
|
|
@ -0,0 +1,20 @@
|
|||
LocalOnly = true
|
||||
|
||||
RecordRequests = true
|
||||
IncludeRequestHeaders = ["Authorization", "User-Agent"]
|
||||
|
||||
[[Repls]]
|
||||
Old = '(linux|darwin|windows)'
|
||||
New = '[OS]'
|
||||
|
||||
[[Repls]]
|
||||
Old = " upstream/[A-Za-z0-9.-]+"
|
||||
New = ""
|
||||
|
||||
[[Repls]]
|
||||
Old = " upstream-version/[A-Za-z0-9.-]+"
|
||||
New = ""
|
||||
|
||||
[[Repls]]
|
||||
Old = " cicd/[A-Za-z0-9.-]+"
|
||||
New = ""
|
|
@ -14,11 +14,7 @@ import (
|
|||
|
||||
func StartCmdServer(t *testing.T) *testserver.Server {
|
||||
server := testserver.New(t)
|
||||
|
||||
// {$} is a wildcard that only matches the end of the URL. We explicitly use
|
||||
// /{$} to disambiguate it from the generic handler for '/' which is used to
|
||||
// identify unhandled API endpoints in the test server.
|
||||
server.Handle("/{$}", func(w *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/", func(_ *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
q := r.URL.Query()
|
||||
args := strings.Split(q.Get("args"), " ")
|
||||
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"method": "GET",
|
||||
"path": "/api/2.0/preview/scim/v2/Me"
|
||||
}
|
||||
{
|
||||
"method": "GET",
|
||||
"path": "/custom/endpoint"
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
>>> curl -s [DATABRICKS_URL]/api/2.0/preview/scim/v2/Me
|
||||
{
|
||||
"id": "[USERID]",
|
||||
"userName": "[USERNAME]"
|
||||
}
|
||||
>>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param
|
||||
HTTP/1.1 201 Created
|
||||
Content-Type: application/json
|
||||
Date: (redacted)
|
||||
Content-Length: (redacted)
|
||||
|
||||
custom
|
||||
---
|
||||
response
|
|
@ -0,0 +1,2 @@
|
|||
trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me
|
||||
trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param
|
|
@ -0,0 +1,18 @@
|
|||
LocalOnly = true
|
||||
RecordRequests = true
|
||||
|
||||
[[Server]]
|
||||
Pattern = "GET /custom/endpoint"
|
||||
Response.Body = '''custom
|
||||
---
|
||||
response
|
||||
'''
|
||||
Response.StatusCode = 201
|
||||
|
||||
[[Repls]]
|
||||
Old = 'Date: .*'
|
||||
New = 'Date: (redacted)'
|
||||
|
||||
[[Repls]]
|
||||
Old = 'Content-Length: [0-9]*'
|
||||
New = 'Content-Length: (redacted)'
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/databricks/databricks-sdk-go/service/catalog"
|
||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/compute"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
|
@ -16,8 +17,13 @@ import (
|
|||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
)
|
||||
|
||||
var testUser = iam.User{
|
||||
Id: "1000012345",
|
||||
UserName: "tester@databricks.com",
|
||||
}
|
||||
|
||||
func AddHandlers(server *testserver.Server) {
|
||||
server.Handle("GET /api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return compute.ListPoliciesResponse{
|
||||
Policies: []compute.Policy{
|
||||
{
|
||||
|
@ -32,7 +38,7 @@ func AddHandlers(server *testserver.Server) {
|
|||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return compute.ListInstancePools{
|
||||
InstancePools: []compute.InstancePoolAndStats{
|
||||
{
|
||||
|
@ -43,7 +49,7 @@ func AddHandlers(server *testserver.Server) {
|
|||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return compute.ListClustersResponse{
|
||||
Clusters: []compute.ClusterDetails{
|
||||
{
|
||||
|
@ -58,20 +64,17 @@ func AddHandlers(server *testserver.Server) {
|
|||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return iam.User{
|
||||
Id: "1000012345",
|
||||
UserName: "tester@databricks.com",
|
||||
}, http.StatusOK
|
||||
server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return testUser, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
path := r.URL.Query().Get("path")
|
||||
|
||||
return fakeWorkspace.WorkspaceGetStatus(path)
|
||||
})
|
||||
|
||||
server.Handle("POST /api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("POST", "/api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
request := workspace.Mkdirs{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
|
||||
|
@ -83,13 +86,13 @@ func AddHandlers(server *testserver.Server) {
|
|||
return fakeWorkspace.WorkspaceMkdirs(request)
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
path := r.URL.Query().Get("path")
|
||||
|
||||
return fakeWorkspace.WorkspaceExport(path)
|
||||
})
|
||||
|
||||
server.Handle("POST /api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("POST", "/api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
path := r.URL.Query().Get("path")
|
||||
recursiveStr := r.URL.Query().Get("recursive")
|
||||
var recursive bool
|
||||
|
@ -103,8 +106,9 @@ func AddHandlers(server *testserver.Server) {
|
|||
return fakeWorkspace.WorkspaceDelete(path, recursive)
|
||||
})
|
||||
|
||||
server.Handle("POST /api/2.0/workspace-files/import-file/{path}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
path := r.PathValue("path")
|
||||
server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
vars := mux.Vars(r)
|
||||
path := vars["path"]
|
||||
|
||||
body := new(bytes.Buffer)
|
||||
_, err := body.ReadFrom(r.Body)
|
||||
|
@ -115,14 +119,15 @@ func AddHandlers(server *testserver.Server) {
|
|||
return fakeWorkspace.WorkspaceFilesImportFile(path, body.Bytes())
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return catalog.MetastoreAssignment{
|
||||
DefaultCatalogName: "main",
|
||||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
objectId := r.PathValue("objectId")
|
||||
server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
vars := mux.Vars(r)
|
||||
objectId := vars["objectId"]
|
||||
|
||||
return workspace.WorkspaceObjectPermissions{
|
||||
ObjectId: objectId,
|
||||
|
@ -140,7 +145,7 @@ func AddHandlers(server *testserver.Server) {
|
|||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("POST /api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("POST", "/api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
request := jobs.CreateJob{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
|
||||
|
@ -152,15 +157,31 @@ func AddHandlers(server *testserver.Server) {
|
|||
return fakeWorkspace.JobsCreate(request)
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
jobId := r.URL.Query().Get("job_id")
|
||||
|
||||
return fakeWorkspace.JobsGet(jobId)
|
||||
})
|
||||
|
||||
server.Handle("GET /api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
server.Handle("GET", "/api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return fakeWorkspace.JobsList()
|
||||
})
|
||||
|
||||
server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return map[string]string{
|
||||
"authorization_endpoint": server.URL + "oidc/v1/authorize",
|
||||
"token_endpoint": server.URL + "/oidc/v1/token",
|
||||
}, http.StatusOK
|
||||
})
|
||||
|
||||
server.Handle("POST", "/oidc/v1/token", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) {
|
||||
return map[string]string{
|
||||
"access_token": "oauth-token",
|
||||
"expires_in": "3600",
|
||||
"scope": "all-apis",
|
||||
"token_type": "Bearer",
|
||||
}, http.StatusOK
|
||||
})
|
||||
}
|
||||
|
||||
func internalError(err error) (any, int) {
|
||||
|
|
|
@ -31,6 +31,7 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
|
|||
cmd.AddCommand(newProfilesCommand())
|
||||
cmd.AddCommand(newTokenCommand(&perisistentAuth))
|
||||
cmd.AddCommand(newDescribeCommand())
|
||||
cmd.AddCommand(newLogoutCommand(&perisistentAuth))
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
|
||||
"github.com/databricks/cli/libs/auth"
|
||||
"github.com/databricks/cli/libs/auth/cache"
|
||||
"github.com/databricks/cli/libs/cmdio"
|
||||
"github.com/databricks/cli/libs/databrickscfg/profile"
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type logoutSession struct {
|
||||
profile string
|
||||
file config.File
|
||||
persistentAuth *auth.PersistentAuth
|
||||
}
|
||||
|
||||
func (l *logoutSession) load(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth) error {
|
||||
l.profile = profileName
|
||||
l.persistentAuth = persistentAuth
|
||||
iniFile, err := profile.DefaultProfiler.Get(ctx)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("cannot parse config file: %w", err)
|
||||
}
|
||||
l.file = *iniFile
|
||||
if err := l.setHostAndAccountIdFromProfile(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logoutSession) setHostAndAccountIdFromProfile() error {
|
||||
sectionMap, err := l.getConfigSectionMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sectionMap["host"] == "" {
|
||||
return fmt.Errorf("no host configured for profile %s", l.profile)
|
||||
}
|
||||
l.persistentAuth.Host = sectionMap["host"]
|
||||
l.persistentAuth.AccountID = sectionMap["account_id"]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logoutSession) getConfigSectionMap() (map[string]string, error) {
|
||||
section, err := l.file.GetSection(l.profile)
|
||||
if err != nil {
|
||||
return map[string]string{}, fmt.Errorf("profile does not exist in config file: %w", err)
|
||||
}
|
||||
return section.KeysHash(), nil
|
||||
}
|
||||
|
||||
// clear token from ~/.databricks/token-cache.json
|
||||
func (l *logoutSession) clearTokenCache(ctx context.Context) error {
|
||||
return l.persistentAuth.ClearToken(ctx)
|
||||
}
|
||||
|
||||
func newLogoutCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "logout [PROFILE]",
|
||||
Short: "Logout from specified profile",
|
||||
Long: "Removes the OAuth token from the token-cache",
|
||||
}
|
||||
|
||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
profileNameFromFlag := cmd.Flag("profile").Value.String()
|
||||
// If both [PROFILE] and --profile are provided, return an error.
|
||||
if len(args) > 0 && profileNameFromFlag != "" {
|
||||
return fmt.Errorf("please only provide a profile as an argument or a flag, not both")
|
||||
}
|
||||
// Determine the profile name from either args or the flag.
|
||||
profileName := profileNameFromFlag
|
||||
if len(args) > 0 {
|
||||
profileName = args[0]
|
||||
}
|
||||
// If the user has not specified a profile name, prompt for one.
|
||||
if profileName == "" {
|
||||
var err error
|
||||
profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer persistentAuth.Close()
|
||||
logoutSession := &logoutSession{}
|
||||
err := logoutSession.load(ctx, profileName, persistentAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = logoutSession.clearTokenCache(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, cache.ErrNotConfigured) {
|
||||
// It is OK to not have OAuth configured
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cmdio.LogString(ctx, fmt.Sprintf("Profile %s is logged out", profileName))
|
||||
return nil
|
||||
}
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/databricks/cli/libs/auth"
|
||||
"github.com/databricks/cli/libs/databrickscfg"
|
||||
"github.com/databricks/databricks-sdk-go/config"
|
||||
)
|
||||
|
||||
func TestLogout_setHostAndAccountIdFromProfile(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
path := filepath.Join(t.TempDir(), "databrickscfg")
|
||||
|
||||
err := databrickscfg.SaveToProfile(ctx, &config.Config{
|
||||
ConfigFile: path,
|
||||
Profile: "abc",
|
||||
Host: "https://foo",
|
||||
Token: "xyz",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
iniFile, err := config.LoadFile(path)
|
||||
require.NoError(t, err)
|
||||
logout := &logoutSession{
|
||||
profile: "abc",
|
||||
file: *iniFile,
|
||||
persistentAuth: &auth.PersistentAuth{},
|
||||
}
|
||||
err = logout.setHostAndAccountIdFromProfile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, logout.persistentAuth.Host, "https://foo")
|
||||
assert.Empty(t, logout.persistentAuth.AccountID)
|
||||
}
|
||||
|
||||
func TestLogout_getConfigSectionMap(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
path := filepath.Join(t.TempDir(), "databrickscfg")
|
||||
|
||||
err := databrickscfg.SaveToProfile(ctx, &config.Config{
|
||||
ConfigFile: path,
|
||||
Profile: "abc",
|
||||
Host: "https://foo",
|
||||
Token: "xyz",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
iniFile, err := config.LoadFile(path)
|
||||
require.NoError(t, err)
|
||||
logout := &logoutSession{
|
||||
profile: "abc",
|
||||
file: *iniFile,
|
||||
persistentAuth: &auth.PersistentAuth{},
|
||||
}
|
||||
configSectionMap, err := logout.getConfigSectionMap()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, configSectionMap["host"], "https://foo")
|
||||
assert.Equal(t, configSectionMap["token"], "xyz")
|
||||
}
|
1
go.mod
1
go.mod
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/databricks/databricks-sdk-go v0.57.0 // Apache 2.0
|
||||
github.com/fatih/color v1.18.0 // MIT
|
||||
github.com/google/uuid v1.6.0 // BSD-3-Clause
|
||||
github.com/gorilla/mux v1.8.1 // BSD 3-Clause
|
||||
github.com/hashicorp/go-version v1.7.0 // MPL 2.0
|
||||
github.com/hashicorp/hc-install v0.9.1 // MPL 2.0
|
||||
github.com/hashicorp/terraform-exec v0.22.0 // MPL 2.0
|
||||
|
|
|
@ -97,6 +97,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
|
|||
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
|
||||
github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg=
|
||||
github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
|
||||
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
type TokenCache interface {
|
||||
Store(key string, t *oauth2.Token) error
|
||||
Lookup(key string) (*oauth2.Token, error)
|
||||
Delete(key string) error
|
||||
}
|
||||
|
||||
var tokenCache int
|
||||
|
|
|
@ -52,11 +52,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error {
|
|||
c.Tokens = map[string]*oauth2.Token{}
|
||||
}
|
||||
c.Tokens[key] = t
|
||||
raw, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
||||
return c.write()
|
||||
}
|
||||
|
||||
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
||||
|
@ -73,6 +69,24 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func (c *FileTokenCache) Delete(key string) error {
|
||||
err := c.load()
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return ErrNotConfigured
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("load: %w", err)
|
||||
}
|
||||
if c.Tokens == nil {
|
||||
c.Tokens = map[string]*oauth2.Token{}
|
||||
}
|
||||
_, ok := c.Tokens[key]
|
||||
if !ok {
|
||||
return ErrNotConfigured
|
||||
}
|
||||
delete(c.Tokens, key)
|
||||
return c.write()
|
||||
}
|
||||
|
||||
func (c *FileTokenCache) location() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
@ -105,4 +119,12 @@ func (c *FileTokenCache) load() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *FileTokenCache) write() error {
|
||||
raw, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
|
||||
}
|
||||
|
||||
var _ TokenCache = (*FileTokenCache)(nil)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
@ -103,3 +104,64 @@ func TestStoreOnDev(t *testing.T) {
|
|||
// macOS: read-only file system
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoreAndDeleteKey(t *testing.T) {
|
||||
setup(t)
|
||||
c := &FileTokenCache{}
|
||||
err := c.Store("x", &oauth2.Token{
|
||||
AccessToken: "abc",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Store("y", &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
l := &FileTokenCache{}
|
||||
err = l.Delete("x")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(l.Tokens))
|
||||
|
||||
_, err = l.Lookup("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
|
||||
tok, err := l.Lookup("y")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bcd", tok.AccessToken)
|
||||
}
|
||||
|
||||
func TestDeleteKeyNotExist(t *testing.T) {
|
||||
c := &FileTokenCache{
|
||||
Tokens: map[string]*oauth2.Token{},
|
||||
}
|
||||
err := c.Delete("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
|
||||
_, err = c.Lookup("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
tempFile := filepath.Join(t.TempDir(), "token-cache.json")
|
||||
|
||||
tokenMap := map[string]*oauth2.Token{}
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "some-access-token",
|
||||
}
|
||||
tokenMap["test"] = token
|
||||
|
||||
cache := &FileTokenCache{
|
||||
fileLocation: tempFile,
|
||||
Tokens: tokenMap,
|
||||
}
|
||||
|
||||
err := cache.write()
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tempFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, _ := json.MarshalIndent(&cache, "", " ")
|
||||
assert.Equal(t, content, expected)
|
||||
}
|
||||
|
|
|
@ -23,4 +23,14 @@ func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Delete implements TokenCache.
|
||||
func (i *InMemoryTokenCache) Delete(key string) error {
|
||||
_, ok := i.Tokens[key]
|
||||
if !ok {
|
||||
return ErrNotConfigured
|
||||
}
|
||||
delete(i.Tokens, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ TokenCache = (*InMemoryTokenCache)(nil)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
|
@ -42,3 +43,40 @@ func TestInMemoryCacheStore(t *testing.T) {
|
|||
assert.Equal(t, res, token)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInMemoryDeleteKey(t *testing.T) {
|
||||
c := &InMemoryTokenCache{
|
||||
Tokens: map[string]*oauth2.Token{},
|
||||
}
|
||||
err := c.Store("x", &oauth2.Token{
|
||||
AccessToken: "abc",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Store("y", &oauth2.Token{
|
||||
AccessToken: "bcd",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Delete("x")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(c.Tokens))
|
||||
|
||||
_, err = c.Lookup("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
|
||||
tok, err := c.Lookup("y")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bcd", tok.AccessToken)
|
||||
}
|
||||
|
||||
func TestInMemoryDeleteKeyNotExist(t *testing.T) {
|
||||
c := &InMemoryTokenCache{
|
||||
Tokens: map[string]*oauth2.Token{},
|
||||
}
|
||||
err := c.Delete("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
|
||||
_, err = c.Lookup("x")
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
|
|
@ -144,6 +144,18 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *PersistentAuth) ClearToken(ctx context.Context) error {
|
||||
if a.Host == "" && a.AccountID == "" {
|
||||
return ErrFetchCredentials
|
||||
}
|
||||
if a.cache == nil {
|
||||
a.cache = cache.GetTokenCache(ctx)
|
||||
}
|
||||
// lookup token identified by host (and possibly the account id)
|
||||
key := a.key()
|
||||
return a.cache.Delete(key)
|
||||
}
|
||||
|
||||
// This function cleans up the host URL by only retaining the scheme and the host.
|
||||
// This function thus removes any path, query arguments, or fragments from the URL.
|
||||
func (a *PersistentAuth) cleanHost() {
|
||||
|
|
|
@ -56,6 +56,7 @@ func TestOidcForWorkspace(t *testing.T) {
|
|||
type tokenCacheMock struct {
|
||||
store func(key string, t *oauth2.Token) error
|
||||
lookup func(key string) (*oauth2.Token, error)
|
||||
delete func(key string) error
|
||||
}
|
||||
|
||||
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
|
||||
|
@ -72,6 +73,13 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
|
|||
return m.lookup(key)
|
||||
}
|
||||
|
||||
func (m *tokenCacheMock) Delete(key string) error {
|
||||
if m.delete == nil {
|
||||
panic("no deleteKey mock")
|
||||
}
|
||||
return m.delete(key)
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
|
@ -232,6 +240,52 @@ func TestChallengeFailed(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestClearToken(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return &oauth2.Token{}, ErrNotConfigured
|
||||
},
|
||||
delete: func(key string) error {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
err := p.ClearToken(context.Background())
|
||||
assert.NoError(t, err)
|
||||
key := p.key()
|
||||
_, err = p.cache.Lookup(key)
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
||||
func TestClearTokenNotExist(t *testing.T) {
|
||||
p := &PersistentAuth{
|
||||
Host: "abc",
|
||||
AccountID: "xyz",
|
||||
cache: &tokenCacheMock{
|
||||
lookup: func(key string) (*oauth2.Token, error) {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return &oauth2.Token{}, ErrNotConfigured
|
||||
},
|
||||
delete: func(key string) error {
|
||||
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||||
return ErrNotConfigured
|
||||
},
|
||||
},
|
||||
}
|
||||
defer p.Close()
|
||||
err := p.ClearToken(context.Background())
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
key := p.key()
|
||||
_, err = p.cache.Lookup(key)
|
||||
assert.Equal(t, ErrNotConfigured, err)
|
||||
}
|
||||
|
||||
func TestPersistentAuthCleanHost(t *testing.T) {
|
||||
for _, tcases := range []struct {
|
||||
in string
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/databricks/cli/internal/testutil"
|
||||
|
@ -17,7 +19,7 @@ import (
|
|||
|
||||
type Server struct {
|
||||
*httptest.Server
|
||||
Mux *http.ServeMux
|
||||
Router *mux.Router
|
||||
|
||||
t testutil.TestingT
|
||||
|
||||
|
@ -34,26 +36,25 @@ type Request struct {
|
|||
Headers http.Header `json:"headers,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Body any `json:"body"`
|
||||
Body any `json:"body,omitempty"`
|
||||
RawBody string `json:"raw_body,omitempty"`
|
||||
}
|
||||
|
||||
func New(t testutil.TestingT) *Server {
|
||||
mux := http.NewServeMux()
|
||||
server := httptest.NewServer(mux)
|
||||
router := mux.NewRouter()
|
||||
server := httptest.NewServer(router)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
s := &Server{
|
||||
Server: server,
|
||||
Mux: mux,
|
||||
Router: router,
|
||||
t: t,
|
||||
mu: &sync.Mutex{},
|
||||
fakeWorkspaces: map[string]*FakeWorkspace{},
|
||||
}
|
||||
|
||||
// The server resolves conflicting handlers by using the one with higher
|
||||
// specificity. This handler is the least specific, so it will be used as a
|
||||
// fallback when no other handlers match.
|
||||
s.Handle("/", func(fakeWorkspace *FakeWorkspace, r *http.Request) (any, int) {
|
||||
// Set up the not found handler as fallback
|
||||
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
pattern := r.Method + " " + r.URL.Path
|
||||
|
||||
t.Errorf(`
|
||||
|
@ -74,9 +75,22 @@ Response.StatusCode = <response status-code here>
|
|||
|
||||
`, pattern, pattern)
|
||||
|
||||
return apierr.APIError{
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
|
||||
resp := apierr.APIError{
|
||||
Message: "No stub found for pattern: " + pattern,
|
||||
}, http.StatusNotImplemented
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Errorf("JSON encoding error: %s", err)
|
||||
respBytes = []byte("{\"message\": \"JSON encoding error\"}")
|
||||
}
|
||||
|
||||
if _, err := w.Write(respBytes); err != nil {
|
||||
t.Errorf("Response write error: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
return s
|
||||
|
@ -84,8 +98,8 @@ Response.StatusCode = <response status-code here>
|
|||
|
||||
type HandlerFunc func(fakeWorkspace *FakeWorkspace, req *http.Request) (resp any, statusCode int)
|
||||
|
||||
func (s *Server) Handle(pattern string, handler HandlerFunc) {
|
||||
s.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) Handle(method, path string, handler HandlerFunc) {
|
||||
s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
// For simplicity we process requests sequentially. It's fast enough because
|
||||
// we don't do any IO except reading and writing request/response bodies.
|
||||
s.mu.Lock()
|
||||
|
@ -119,13 +133,19 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) {
|
|||
}
|
||||
}
|
||||
|
||||
s.Requests = append(s.Requests, Request{
|
||||
req := Request{
|
||||
Headers: headers,
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Body: json.RawMessage(body),
|
||||
})
|
||||
}
|
||||
|
||||
if json.Valid(body) {
|
||||
req.Body = json.RawMessage(body)
|
||||
} else {
|
||||
req.RawBody = string(body)
|
||||
}
|
||||
|
||||
s.Requests = append(s.Requests, req)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
@ -149,7 +169,7 @@ func (s *Server) Handle(pattern string, handler HandlerFunc) {
|
|||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})
|
||||
}).Methods(method)
|
||||
}
|
||||
|
||||
func getToken(r *http.Request) string {
|
||||
|
|
Loading…
Reference in New Issue