mirror of https://github.com/databricks/cli.git
236 lines
6.0 KiB
Go
236 lines
6.0 KiB
Go
|
package auth
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
_ "embed"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/databricks/databricks-sdk-go/client"
|
||
|
"github.com/databricks/databricks-sdk-go/qa"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"golang.org/x/oauth2"
|
||
|
)
|
||
|
|
||
|
func TestOidcEndpointsForAccounts(t *testing.T) {
|
||
|
p := &PersistentAuth{
|
||
|
Host: "abc",
|
||
|
AccountID: "xyz",
|
||
|
}
|
||
|
defer p.Close()
|
||
|
s, err := p.oidcEndpoints()
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
|
||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
|
||
|
}
|
||
|
|
||
|
// mockGet lets a plain function stand in for the HTTP getter that
// PersistentAuth uses, so tests can stub responses without a network.
type mockGet func(url string) (*http.Response, error)

// Get forwards the request URL to the wrapped function and returns its
// response and error unchanged.
func (get mockGet) Get(rawURL string) (*http.Response, error) {
	resp, err := get(rawURL)
	return resp, err
}
|
||
|
|
||
|
func TestOidcForWorkspace(t *testing.T) {
|
||
|
p := &PersistentAuth{
|
||
|
Host: "abc",
|
||
|
http: mockGet(func(url string) (*http.Response, error) {
|
||
|
assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
|
||
|
return &http.Response{
|
||
|
StatusCode: 200,
|
||
|
Body: io.NopCloser(strings.NewReader(`{
|
||
|
"authorization_endpoint": "a",
|
||
|
"token_endpoint": "b"
|
||
|
}`)),
|
||
|
}, nil
|
||
|
}),
|
||
|
}
|
||
|
defer p.Close()
|
||
|
endpoints, err := p.oidcEndpoints()
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
|
||
|
assert.Equal(t, "b", endpoints.TokenEndpoint)
|
||
|
}
|
||
|
|
||
|
type tokenCacheMock struct {
|
||
|
store func(key string, t *oauth2.Token) error
|
||
|
lookup func(key string) (*oauth2.Token, error)
|
||
|
}
|
||
|
|
||
|
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
|
||
|
if m.store == nil {
|
||
|
panic("no store mock")
|
||
|
}
|
||
|
return m.store(key, t)
|
||
|
}
|
||
|
|
||
|
func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
|
||
|
if m.lookup == nil {
|
||
|
panic("no lookup mock")
|
||
|
}
|
||
|
return m.lookup(key)
|
||
|
}
|
||
|
|
||
|
func TestLoad(t *testing.T) {
|
||
|
p := &PersistentAuth{
|
||
|
Host: "abc",
|
||
|
AccountID: "xyz",
|
||
|
cache: &tokenCacheMock{
|
||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||
|
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
|
||
|
return &oauth2.Token{
|
||
|
AccessToken: "bcd",
|
||
|
Expiry: time.Now().Add(1 * time.Minute),
|
||
|
}, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
defer p.Close()
|
||
|
tok, err := p.Load(context.Background())
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, "bcd", tok.AccessToken)
|
||
|
assert.Equal(t, "", tok.RefreshToken)
|
||
|
}
|
||
|
|
||
|
func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context {
|
||
|
return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
|
||
|
Transport: &http.Transport{
|
||
|
TLSClientConfig: &tls.Config{
|
||
|
InsecureSkipVerify: true,
|
||
|
},
|
||
|
},
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestLoadRefresh(t *testing.T) {
|
||
|
qa.HTTPFixtures{
|
||
|
{
|
||
|
Method: "POST",
|
||
|
Resource: "/oidc/accounts/xyz/v1/token",
|
||
|
Response: `access_token=refreshed&refresh_token=def`,
|
||
|
},
|
||
|
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||
|
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||
|
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
|
||
|
p := &PersistentAuth{
|
||
|
Host: c.Config.Host,
|
||
|
AccountID: "xyz",
|
||
|
cache: &tokenCacheMock{
|
||
|
lookup: func(key string) (*oauth2.Token, error) {
|
||
|
assert.Equal(t, expectedKey, key)
|
||
|
return &oauth2.Token{
|
||
|
AccessToken: "expired",
|
||
|
RefreshToken: "cde",
|
||
|
Expiry: time.Now().Add(-1 * time.Minute),
|
||
|
}, nil
|
||
|
},
|
||
|
store: func(key string, tok *oauth2.Token) error {
|
||
|
assert.Equal(t, expectedKey, key)
|
||
|
assert.Equal(t, "def", tok.RefreshToken)
|
||
|
return nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
defer p.Close()
|
||
|
tok, err := p.Load(ctx)
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, "refreshed", tok.AccessToken)
|
||
|
assert.Equal(t, "", tok.RefreshToken)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestChallenge(t *testing.T) {
|
||
|
qa.HTTPFixtures{
|
||
|
{
|
||
|
Method: "POST",
|
||
|
Resource: "/oidc/accounts/xyz/v1/token",
|
||
|
Response: `access_token=__THAT__&refresh_token=__SOMETHING__`,
|
||
|
},
|
||
|
}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||
|
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||
|
expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
|
||
|
|
||
|
browserOpened := make(chan string)
|
||
|
p := &PersistentAuth{
|
||
|
Host: c.Config.Host,
|
||
|
AccountID: "xyz",
|
||
|
browser: func(redirect string) error {
|
||
|
u, err := url.ParseRequestURI(redirect)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
|
||
|
// for now we're ignoring asserting the fields of the redirect
|
||
|
query := u.Query()
|
||
|
browserOpened <- query.Get("state")
|
||
|
return nil
|
||
|
},
|
||
|
cache: &tokenCacheMock{
|
||
|
store: func(key string, tok *oauth2.Token) error {
|
||
|
assert.Equal(t, expectedKey, key)
|
||
|
assert.Equal(t, "__SOMETHING__", tok.RefreshToken)
|
||
|
return nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
defer p.Close()
|
||
|
|
||
|
errc := make(chan error)
|
||
|
go func() {
|
||
|
errc <- p.Challenge(ctx)
|
||
|
}()
|
||
|
|
||
|
state := <-browserOpened
|
||
|
resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, 200, resp.StatusCode)
|
||
|
|
||
|
err = <-errc
|
||
|
assert.NoError(t, err)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestChallengeFailed(t *testing.T) {
|
||
|
qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
|
||
|
ctx = useInsecureOAuthHttpClientForTests(ctx)
|
||
|
|
||
|
browserOpened := make(chan string)
|
||
|
p := &PersistentAuth{
|
||
|
Host: c.Config.Host,
|
||
|
AccountID: "xyz",
|
||
|
browser: func(redirect string) error {
|
||
|
u, err := url.ParseRequestURI(redirect)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
|
||
|
// for now we're ignoring asserting the fields of the redirect
|
||
|
query := u.Query()
|
||
|
browserOpened <- query.Get("state")
|
||
|
return nil
|
||
|
},
|
||
|
}
|
||
|
defer p.Close()
|
||
|
|
||
|
errc := make(chan error)
|
||
|
go func() {
|
||
|
errc <- p.Challenge(ctx)
|
||
|
}()
|
||
|
|
||
|
<-browserOpened
|
||
|
resp, err := http.Get(fmt.Sprintf(
|
||
|
"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
|
||
|
appRedirectAddr))
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, 400, resp.StatusCode)
|
||
|
|
||
|
err = <-errc
|
||
|
assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
|
||
|
})
|
||
|
}
|