package auth

import (
	"context"
	"crypto/tls"
	_ "embed"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/databricks/databricks-sdk-go/client"
	"github.com/databricks/databricks-sdk-go/qa"
	"github.com/stretchr/testify/assert"
	"golang.org/x/oauth2"
)

func TestOidcEndpointsForAccounts(t *testing.T) {
	p := &PersistentAuth{
		Host:      "abc",
		AccountID: "xyz",
	}
	defer p.Close()
	s, err := p.oidcEndpoints()
	assert.NoError(t, err)
	assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
	assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
}

type mockGet func(url string) (*http.Response, error)

func (m mockGet) Get(url string) (*http.Response, error) {
	return m(url)
}

func TestOidcForWorkspace(t *testing.T) {
	p := &PersistentAuth{
		Host: "abc",
		http: mockGet(func(url string) (*http.Response, error) {
			assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
			return &http.Response{
				StatusCode: 200,
				Body: io.NopCloser(strings.NewReader(`{
					"authorization_endpoint": "a",
					"token_endpoint": "b"
				}`)),
			}, nil
		}),
	}
	defer p.Close()
	endpoints, err := p.oidcEndpoints()
	assert.NoError(t, err)
	assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
	assert.Equal(t, "b", endpoints.TokenEndpoint)
}

type tokenCacheMock struct {
	store  func(key string, t *oauth2.Token) error
	lookup func(key string) (*oauth2.Token, error)
}

func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
	if m.store == nil {
		panic("no store mock")
	}
	return m.store(key, t)
}

func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
	if m.lookup == nil {
		panic("no lookup mock")
	}
	return m.lookup(key)
}

func TestLoad(t *testing.T) {
	p := &PersistentAuth{
		Host:      "abc",
		AccountID: "xyz",
		cache: &tokenCacheMock{
			lookup: func(key string) (*oauth2.Token, error) {
				assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
				return &oauth2.Token{
					AccessToken: "bcd",
					Expiry:      time.Now().Add(1 * time.Minute),
				}, nil
			},
		},
	}
	defer p.Close()
	tok, err := p.Load(context.Background())
	assert.NoError(t, err)
	assert.Equal(t, "bcd", tok.AccessToken)
	assert.Equal(t, "", tok.RefreshToken)
}

func useInsecureOAuthHttpClientForTests(ctx context.Context) context.Context {
	return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
	})
}

func TestLoadRefresh(t *testing.T) {
	qa.HTTPFixtures{
		{
			Method:   "POST",
			Resource: "/oidc/accounts/xyz/v1/token",
			Response: `access_token=refreshed&refresh_token=def`,
		},
	}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
		ctx = useInsecureOAuthHttpClientForTests(ctx)
		expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)
		p := &PersistentAuth{
			Host:      c.Config.Host,
			AccountID: "xyz",
			cache: &tokenCacheMock{
				lookup: func(key string) (*oauth2.Token, error) {
					assert.Equal(t, expectedKey, key)
					return &oauth2.Token{
						AccessToken:  "expired",
						RefreshToken: "cde",
						Expiry:       time.Now().Add(-1 * time.Minute),
					}, nil
				},
				store: func(key string, tok *oauth2.Token) error {
					assert.Equal(t, expectedKey, key)
					assert.Equal(t, "def", tok.RefreshToken)
					return nil
				},
			},
		}
		defer p.Close()
		tok, err := p.Load(ctx)
		assert.NoError(t, err)
		assert.Equal(t, "refreshed", tok.AccessToken)
		assert.Equal(t, "", tok.RefreshToken)
	})
}

func TestChallenge(t *testing.T) {
	qa.HTTPFixtures{
		{
			Method:   "POST",
			Resource: "/oidc/accounts/xyz/v1/token",
			Response: `access_token=__THAT__&refresh_token=__SOMETHING__`,
		},
	}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
		ctx = useInsecureOAuthHttpClientForTests(ctx)
		expectedKey := fmt.Sprintf("%s/oidc/accounts/xyz", c.Config.Host)

		browserOpened := make(chan string)
		p := &PersistentAuth{
			Host:      c.Config.Host,
			AccountID: "xyz",
			browser: func(redirect string) error {
				u, err := url.ParseRequestURI(redirect)
				if err != nil {
					return err
				}
				assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
				// for now we're ignoring asserting the fields of the redirect
				query := u.Query()
				browserOpened <- query.Get("state")
				return nil
			},
			cache: &tokenCacheMock{
				store: func(key string, tok *oauth2.Token) error {
					assert.Equal(t, expectedKey, key)
					assert.Equal(t, "__SOMETHING__", tok.RefreshToken)
					return nil
				},
			},
		}
		defer p.Close()

		errc := make(chan error)
		go func() {
			errc <- p.Challenge(ctx)
		}()

		state := <-browserOpened
		resp, err := http.Get(fmt.Sprintf("http://%s?code=__THIS__&state=%s", appRedirectAddr, state))
		assert.NoError(t, err)
		assert.Equal(t, 200, resp.StatusCode)

		err = <-errc
		assert.NoError(t, err)
	})
}

func TestChallengeFailed(t *testing.T) {
	qa.HTTPFixtures{}.ApplyClient(t, func(ctx context.Context, c *client.DatabricksClient) {
		ctx = useInsecureOAuthHttpClientForTests(ctx)

		browserOpened := make(chan string)
		p := &PersistentAuth{
			Host:      c.Config.Host,
			AccountID: "xyz",
			browser: func(redirect string) error {
				u, err := url.ParseRequestURI(redirect)
				if err != nil {
					return err
				}
				assert.Equal(t, "/oidc/accounts/xyz/v1/authorize", u.Path)
				// for now we're ignoring asserting the fields of the redirect
				query := u.Query()
				browserOpened <- query.Get("state")
				return nil
			},
		}
		defer p.Close()

		errc := make(chan error)
		go func() {
			errc <- p.Challenge(ctx)
		}()

		<-browserOpened
		resp, err := http.Get(fmt.Sprintf(
			"http://%s?error=access_denied&error_description=Policy%%20evaluation%%20failed%%20for%%20this%%20request",
			appRedirectAddr))
		assert.NoError(t, err)
		assert.Equal(t, 400, resp.StatusCode)

		err = <-errc
		assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request")
	})
}