mirror of https://github.com/databricks/cli.git
Compare commits
9 Commits
b6ab3cf4ba
...
a69dc7f002
Author | SHA1 | Date |
---|---|---|
|
a69dc7f002 | |
|
2b5cd6a2a7 | |
|
089ea4fa97 | |
|
4f6221e14b | |
|
1ce20a2612 | |
|
1306e5ec67 | |
|
d5d6d2f49a | |
|
261b7f4083 | |
|
e088d0d996 |
|
@ -0,0 +1 @@
|
||||||
|
* @pietern @andrewnester @shreyas-goenka @denik
|
2
Makefile
2
Makefile
|
@ -2,7 +2,7 @@ default: build
|
||||||
|
|
||||||
lint: vendor
|
lint: vendor
|
||||||
@echo "✓ Linting source code with https://golangci-lint.run/ (with --fix)..."
|
@echo "✓ Linting source code with https://golangci-lint.run/ (with --fix)..."
|
||||||
@golangci-lint run --fix ./...
|
@./lint.sh ./...
|
||||||
|
|
||||||
lintcheck: vendor
|
lintcheck: vendor
|
||||||
@echo "✓ Linting source code with https://golangci-lint.run/ ..."
|
@echo "✓ Linting source code with https://golangci-lint.run/ ..."
|
||||||
|
|
|
@ -7,11 +7,11 @@ import (
|
||||||
"github.com/databricks/cli/libs/diag"
|
"github.com/databricks/cli/libs/diag"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FastValidate runs a set of fast validation checks. This is a subset of the full
|
// FastValidate runs a subset of fast validation checks. This is a subset of the full
|
||||||
// suite of validation mutators that satisfy ANY ONE of the following criteria:
|
// suite of validation mutators that satisfy ANY ONE of the following criteria:
|
||||||
//
|
//
|
||||||
// 1. No file i/o or network requests are made in the mutator.
|
// 1. No file i/o or network requests are made in the mutator.
|
||||||
// 2. Only returns errors which are blocking for a bundle deployment.
|
// 2. The validation is blocking for bundle deployments.
|
||||||
//
|
//
|
||||||
// The full suite of validation mutators is available in the [Validate] mutator.
|
// The full suite of validation mutators is available in the [Validate] mutator.
|
||||||
type fastValidateReadonly struct{}
|
type fastValidateReadonly struct{}
|
||||||
|
|
|
@ -13,9 +13,7 @@ import (
|
||||||
"github.com/databricks/cli/libs/diag"
|
"github.com/databricks/cli/libs/diag"
|
||||||
"github.com/databricks/cli/libs/dyn"
|
"github.com/databricks/cli/libs/dyn"
|
||||||
"github.com/databricks/cli/libs/dyn/dynvar"
|
"github.com/databricks/cli/libs/dyn/dynvar"
|
||||||
"github.com/databricks/cli/libs/log"
|
|
||||||
"github.com/databricks/databricks-sdk-go/apierr"
|
"github.com/databricks/databricks-sdk-go/apierr"
|
||||||
"github.com/databricks/databricks-sdk-go/service/catalog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type validateArtifactPath struct{}
|
type validateArtifactPath struct{}
|
||||||
|
@ -99,7 +97,7 @@ func (v *validateArtifactPath) Apply(ctx context.Context, rb bundle.ReadOnlyBund
|
||||||
}
|
}
|
||||||
volumeFullName := fmt.Sprintf("%s.%s.%s", catalogName, schemaName, volumeName)
|
volumeFullName := fmt.Sprintf("%s.%s.%s", catalogName, schemaName, volumeName)
|
||||||
w := rb.WorkspaceClient()
|
w := rb.WorkspaceClient()
|
||||||
p, err := w.Grants.GetEffectiveBySecurableTypeAndFullName(ctx, catalog.SecurableTypeVolume, volumeFullName)
|
_, err = w.Volumes.ReadByName(ctx, volumeFullName)
|
||||||
|
|
||||||
if errors.Is(err, apierr.ErrPermissionDenied) {
|
if errors.Is(err, apierr.ErrPermissionDenied) {
|
||||||
return wrapErrorMsg(fmt.Sprintf("cannot access volume %s: %s", volumeFullName, err))
|
return wrapErrorMsg(fmt.Sprintf("cannot access volume %s: %s", volumeFullName, err))
|
||||||
|
@ -125,31 +123,7 @@ the artifact_path.`,
|
||||||
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErrorMsg(fmt.Sprintf("could not fetch grants for volume %s: %s", volumeFullName, err))
|
return wrapErrorMsg(fmt.Sprintf("cannot read volume %s: %s", volumeFullName, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
allPrivileges := []catalog.Privilege{}
|
|
||||||
for _, assignments := range p.PrivilegeAssignments {
|
|
||||||
for _, privilege := range assignments.Privileges {
|
|
||||||
allPrivileges = append(allPrivileges, privilege.Privilege)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UC Volumes have the following privileges: [READ_VOLUME, WRITE_VOLUME, MANAGE, ALL_PRIVILEGES, APPLY TAG]
|
|
||||||
// The user needs to have either WRITE_VOLUME or ALL_PRIVILEGES to write to the volume.
|
|
||||||
canWrite := slices.Contains(allPrivileges, catalog.PrivilegeWriteVolume) || slices.Contains(allPrivileges, catalog.PrivilegeAllPrivileges)
|
|
||||||
if !canWrite {
|
|
||||||
log.Debugf(ctx, "Current privileges on Volume at artifact_path: %v", allPrivileges)
|
|
||||||
return wrapErrorMsg(fmt.Sprintf("user does not have WRITE_VOLUME grant on volume %s", volumeFullName))
|
|
||||||
}
|
|
||||||
|
|
||||||
// READ_VOLUME is implied since the user was able to fetch the associated grants with the volume.
|
|
||||||
// We still add this explicit check out of caution incase the API behavior changes in the future.
|
|
||||||
canRead := slices.Contains(allPrivileges, catalog.PrivilegeReadVolume) || slices.Contains(allPrivileges, catalog.PrivilegeAllPrivileges)
|
|
||||||
if !canRead {
|
|
||||||
log.Debugf(ctx, "Current privileges on Volume at artifact_path: %v", allPrivileges)
|
|
||||||
return wrapErrorMsg(fmt.Sprintf("user does not have READ_VOLUME grant on volume %s", volumeFullName))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,8 +45,8 @@ func TestValidateArtifactPathWithVolumeInBundle(t *testing.T) {
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
m := mocks.NewMockWorkspaceClient(t)
|
m := mocks.NewMockWorkspaceClient(t)
|
||||||
api := m.GetMockGrantsAPI()
|
api := m.GetMockVolumesAPI()
|
||||||
api.EXPECT().GetEffectiveBySecurableTypeAndFullName(mock.Anything, catalog.SecurableTypeVolume, "catalogN.schemaN.volumeN").Return(nil, &apierr.APIError{
|
api.EXPECT().ReadByName(mock.Anything, "catalogN.schemaN.volumeN").Return(nil, &apierr.APIError{
|
||||||
StatusCode: 404,
|
StatusCode: 404,
|
||||||
})
|
})
|
||||||
b.SetWorkpaceClient(m.WorkspaceClient)
|
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||||
|
@ -90,22 +90,11 @@ func TestValidateArtifactPath(t *testing.T) {
|
||||||
}}, diags)
|
}}, diags)
|
||||||
}
|
}
|
||||||
|
|
||||||
wrapPrivileges := func(privileges ...catalog.Privilege) *catalog.EffectivePermissionsList {
|
|
||||||
perms := &catalog.EffectivePermissionsList{}
|
|
||||||
for _, p := range privileges {
|
|
||||||
perms.PrivilegeAssignments = append(perms.PrivilegeAssignments, catalog.EffectivePrivilegeAssignment{
|
|
||||||
Privileges: []catalog.EffectivePrivilege{{Privilege: p}},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return perms
|
|
||||||
}
|
|
||||||
|
|
||||||
rb := bundle.ReadOnly(b)
|
rb := bundle.ReadOnly(b)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
tcases := []struct {
|
tcases := []struct {
|
||||||
err error
|
err error
|
||||||
permissions *catalog.EffectivePermissionsList
|
|
||||||
expectedSummary string
|
expectedSummary string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -126,36 +115,18 @@ func TestValidateArtifactPath(t *testing.T) {
|
||||||
StatusCode: 500,
|
StatusCode: 500,
|
||||||
Message: "Internal Server Error",
|
Message: "Internal Server Error",
|
||||||
},
|
},
|
||||||
expectedSummary: "could not fetch grants for volume catalogN.schemaN.volumeN: Internal Server Error",
|
expectedSummary: "cannot read volume catalogN.schemaN.volumeN: Internal Server Error",
|
||||||
},
|
|
||||||
{
|
|
||||||
permissions: wrapPrivileges(catalog.PrivilegeAllPrivileges),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
permissions: wrapPrivileges(catalog.PrivilegeApplyTag, catalog.PrivilegeManage),
|
|
||||||
expectedSummary: "user does not have WRITE_VOLUME grant on volume catalogN.schemaN.volumeN",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
permissions: wrapPrivileges(catalog.PrivilegeWriteVolume),
|
|
||||||
expectedSummary: "user does not have READ_VOLUME grant on volume catalogN.schemaN.volumeN",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
permissions: wrapPrivileges(catalog.PrivilegeWriteVolume, catalog.PrivilegeReadVolume),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tcases {
|
for _, tc := range tcases {
|
||||||
m := mocks.NewMockWorkspaceClient(t)
|
m := mocks.NewMockWorkspaceClient(t)
|
||||||
api := m.GetMockGrantsAPI()
|
api := m.GetMockVolumesAPI()
|
||||||
api.EXPECT().GetEffectiveBySecurableTypeAndFullName(mock.Anything, catalog.SecurableTypeVolume, "catalogN.schemaN.volumeN").Return(tc.permissions, tc.err)
|
api.EXPECT().ReadByName(mock.Anything, "catalogN.schemaN.volumeN").Return(nil, tc.err)
|
||||||
b.SetWorkpaceClient(m.WorkspaceClient)
|
b.SetWorkpaceClient(m.WorkspaceClient)
|
||||||
|
|
||||||
diags := bundle.ApplyReadOnly(ctx, rb, ValidateArtifactPath())
|
diags := bundle.ApplyReadOnly(ctx, rb, ValidateArtifactPath())
|
||||||
if tc.expectedSummary != "" {
|
|
||||||
assertDiags(t, diags, tc.expectedSummary)
|
assertDiags(t, diags, tc.expectedSummary)
|
||||||
} else {
|
|
||||||
assert.Len(t, diags, 0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/databricks/cli/internal/testcli"
|
"github.com/databricks/cli/internal/testcli"
|
||||||
"github.com/databricks/cli/internal/testutil"
|
"github.com/databricks/cli/internal/testutil"
|
||||||
"github.com/databricks/cli/libs/python/pythontest"
|
"github.com/databricks/cli/libs/python/pythontest"
|
||||||
|
"github.com/databricks/cli/libs/testdiff"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,14 +51,14 @@ func testDefaultPython(t *testing.T, pythonVersion string) {
|
||||||
ctx, wt := acc.WorkspaceTest(t)
|
ctx, wt := acc.WorkspaceTest(t)
|
||||||
|
|
||||||
uniqueProjectId := testutil.RandomName("")
|
uniqueProjectId := testutil.RandomName("")
|
||||||
ctx, replacements := testcli.WithReplacementsMap(ctx)
|
ctx, replacements := testdiff.WithReplacementsMap(ctx)
|
||||||
replacements.Set(uniqueProjectId, "$UNIQUE_PRJ")
|
replacements.Set(uniqueProjectId, "$UNIQUE_PRJ")
|
||||||
|
|
||||||
user, err := wt.W.CurrentUser.Me(ctx)
|
user, err := wt.W.CurrentUser.Me(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, user)
|
require.NotNil(t, user)
|
||||||
testcli.PrepareReplacementsUser(t, replacements, *user)
|
testdiff.PrepareReplacementsUser(t, replacements, *user)
|
||||||
testcli.PrepareReplacements(t, replacements, wt.W)
|
testdiff.PrepareReplacements(t, replacements, wt.W)
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
testutil.Chdir(t, tmpDir)
|
testutil.Chdir(t, tmpDir)
|
||||||
|
|
|
@ -3,222 +3,27 @@ package testcli
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/databricks/cli/internal/testutil"
|
"github.com/databricks/cli/internal/testutil"
|
||||||
"github.com/databricks/cli/libs/iamutil"
|
|
||||||
"github.com/databricks/cli/libs/testdiff"
|
"github.com/databricks/cli/libs/testdiff"
|
||||||
"github.com/databricks/databricks-sdk-go"
|
|
||||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var OverwriteMode = os.Getenv("TESTS_OUTPUT") == "OVERWRITE"
|
|
||||||
|
|
||||||
func ReadFile(t testutil.TestingT, ctx context.Context, filename string) string {
|
|
||||||
data, err := os.ReadFile(filename)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
// On CI, on Windows \n in the file somehow end up as \r\n
|
|
||||||
return NormalizeNewlines(string(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
func captureOutput(t testutil.TestingT, ctx context.Context, args []string) string {
|
func captureOutput(t testutil.TestingT, ctx context.Context, args []string) string {
|
||||||
t.Logf("run args: [%s]", strings.Join(args, ", "))
|
t.Logf("run args: [%s]", strings.Join(args, ", "))
|
||||||
r := NewRunner(t, ctx, args...)
|
r := NewRunner(t, ctx, args...)
|
||||||
stdout, stderr, err := r.Run()
|
stdout, stderr, err := r.Run()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
out := stderr.String() + stdout.String()
|
return stderr.String() + stdout.String()
|
||||||
return ReplaceOutput(t, ctx, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteFile(t testutil.TestingT, filename, data string) {
|
|
||||||
t.Logf("Overwriting %s", filename)
|
|
||||||
err := os.WriteFile(filename, []byte(data), 0o644)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertOutput(t testutil.TestingT, ctx context.Context, args []string, expectedPath string) {
|
func AssertOutput(t testutil.TestingT, ctx context.Context, args []string, expectedPath string) {
|
||||||
expected := ReadFile(t, ctx, expectedPath)
|
|
||||||
|
|
||||||
out := captureOutput(t, ctx, args)
|
out := captureOutput(t, ctx, args)
|
||||||
|
testdiff.AssertOutput(t, ctx, out, fmt.Sprintf("Output from %v", args), expectedPath)
|
||||||
if out != expected {
|
|
||||||
actual := fmt.Sprintf("Output from %v", args)
|
|
||||||
testdiff.AssertEqualTexts(t, expectedPath, actual, expected, out)
|
|
||||||
|
|
||||||
if OverwriteMode {
|
|
||||||
WriteFile(t, expectedPath, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertOutputJQ(t testutil.TestingT, ctx context.Context, args []string, expectedPath string, ignorePaths []string) {
|
func AssertOutputJQ(t testutil.TestingT, ctx context.Context, args []string, expectedPath string, ignorePaths []string) {
|
||||||
expected := ReadFile(t, ctx, expectedPath)
|
|
||||||
|
|
||||||
out := captureOutput(t, ctx, args)
|
out := captureOutput(t, ctx, args)
|
||||||
|
testdiff.AssertOutputJQ(t, ctx, out, fmt.Sprintf("Output from %v", args), expectedPath, ignorePaths)
|
||||||
if out != expected {
|
|
||||||
actual := fmt.Sprintf("Output from %v", args)
|
|
||||||
testdiff.AssertEqualJQ(t.(*testing.T), expectedPath, actual, expected, out, ignorePaths)
|
|
||||||
|
|
||||||
if OverwriteMode {
|
|
||||||
WriteFile(t, expectedPath, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
uuidRegex = regexp.MustCompile(`[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}`)
|
|
||||||
numIdRegex = regexp.MustCompile(`[0-9]{3,}`)
|
|
||||||
privatePathRegex = regexp.MustCompile(`(/tmp|/private)(/.*)/([a-zA-Z0-9]+)`)
|
|
||||||
)
|
|
||||||
|
|
||||||
func ReplaceOutput(t testutil.TestingT, ctx context.Context, out string) string {
|
|
||||||
out = NormalizeNewlines(out)
|
|
||||||
replacements := GetReplacementsMap(ctx)
|
|
||||||
if replacements == nil {
|
|
||||||
t.Fatal("WithReplacementsMap was not called")
|
|
||||||
}
|
|
||||||
out = replacements.Replace(out)
|
|
||||||
out = uuidRegex.ReplaceAllString(out, "<UUID>")
|
|
||||||
out = numIdRegex.ReplaceAllString(out, "<NUMID>")
|
|
||||||
out = privatePathRegex.ReplaceAllString(out, "/tmp/.../$3")
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
type key int
|
|
||||||
|
|
||||||
const (
|
|
||||||
replacementsMapKey = key(1)
|
|
||||||
)
|
|
||||||
|
|
||||||
type Replacement struct {
|
|
||||||
Old string
|
|
||||||
New string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ReplacementsContext struct {
|
|
||||||
Repls []Replacement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ReplacementsContext) Replace(s string) string {
|
|
||||||
// QQQ Should probably only replace whole words
|
|
||||||
for _, repl := range r.Repls {
|
|
||||||
s = strings.ReplaceAll(s, repl.Old, repl.New)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ReplacementsContext) Set(old, new string) {
|
|
||||||
if old == "" || new == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Repls = append(r.Repls, Replacement{Old: old, New: new})
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithReplacementsMap(ctx context.Context) (context.Context, *ReplacementsContext) {
|
|
||||||
value := ctx.Value(replacementsMapKey)
|
|
||||||
if value != nil {
|
|
||||||
if existingMap, ok := value.(*ReplacementsContext); ok {
|
|
||||||
return ctx, existingMap
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newMap := &ReplacementsContext{}
|
|
||||||
ctx = context.WithValue(ctx, replacementsMapKey, newMap)
|
|
||||||
return ctx, newMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetReplacementsMap(ctx context.Context) *ReplacementsContext {
|
|
||||||
value := ctx.Value(replacementsMapKey)
|
|
||||||
if value != nil {
|
|
||||||
if existingMap, ok := value.(*ReplacementsContext); ok {
|
|
||||||
return existingMap
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrepareReplacements(t testutil.TestingT, r *ReplacementsContext, w *databricks.WorkspaceClient) {
|
|
||||||
// in some clouds (gcp) w.Config.Host includes "https://" prefix in others it's really just a host (azure)
|
|
||||||
host := strings.TrimPrefix(strings.TrimPrefix(w.Config.Host, "http://"), "https://")
|
|
||||||
r.Set(host, "$DATABRICKS_HOST")
|
|
||||||
r.Set(w.Config.ClusterID, "$DATABRICKS_CLUSTER_ID")
|
|
||||||
r.Set(w.Config.WarehouseID, "$DATABRICKS_WAREHOUSE_ID")
|
|
||||||
r.Set(w.Config.ServerlessComputeID, "$DATABRICKS_SERVERLESS_COMPUTE_ID")
|
|
||||||
r.Set(w.Config.MetadataServiceURL, "$DATABRICKS_METADATA_SERVICE_URL")
|
|
||||||
r.Set(w.Config.AccountID, "$DATABRICKS_ACCOUNT_ID")
|
|
||||||
r.Set(w.Config.Token, "$DATABRICKS_TOKEN")
|
|
||||||
r.Set(w.Config.Username, "$DATABRICKS_USERNAME")
|
|
||||||
r.Set(w.Config.Password, "$DATABRICKS_PASSWORD")
|
|
||||||
r.Set(w.Config.Profile, "$DATABRICKS_CONFIG_PROFILE")
|
|
||||||
r.Set(w.Config.ConfigFile, "$DATABRICKS_CONFIG_FILE")
|
|
||||||
r.Set(w.Config.GoogleServiceAccount, "$DATABRICKS_GOOGLE_SERVICE_ACCOUNT")
|
|
||||||
r.Set(w.Config.GoogleCredentials, "$GOOGLE_CREDENTIALS")
|
|
||||||
r.Set(w.Config.AzureResourceID, "$DATABRICKS_AZURE_RESOURCE_ID")
|
|
||||||
r.Set(w.Config.AzureClientSecret, "$ARM_CLIENT_SECRET")
|
|
||||||
// r.Set(w.Config.AzureClientID, "$ARM_CLIENT_ID")
|
|
||||||
r.Set(w.Config.AzureClientID, "$USERNAME")
|
|
||||||
r.Set(w.Config.AzureTenantID, "$ARM_TENANT_ID")
|
|
||||||
r.Set(w.Config.ActionsIDTokenRequestURL, "$ACTIONS_ID_TOKEN_REQUEST_URL")
|
|
||||||
r.Set(w.Config.ActionsIDTokenRequestToken, "$ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
|
||||||
r.Set(w.Config.AzureEnvironment, "$ARM_ENVIRONMENT")
|
|
||||||
r.Set(w.Config.ClientID, "$DATABRICKS_CLIENT_ID")
|
|
||||||
r.Set(w.Config.ClientSecret, "$DATABRICKS_CLIENT_SECRET")
|
|
||||||
r.Set(w.Config.DatabricksCliPath, "$DATABRICKS_CLI_PATH")
|
|
||||||
// This is set to words like "path" that happen too frequently
|
|
||||||
// r.Set(w.Config.AuthType, "$DATABRICKS_AUTH_TYPE")
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrepareReplacementsUser(t testutil.TestingT, r *ReplacementsContext, u iam.User) {
|
|
||||||
// There could be exact matches or overlap between different name fields, so sort them by length
|
|
||||||
// to ensure we match the largest one first and map them all to the same token
|
|
||||||
names := []string{
|
|
||||||
u.DisplayName,
|
|
||||||
u.UserName,
|
|
||||||
iamutil.GetShortUserName(&u),
|
|
||||||
u.Name.FamilyName,
|
|
||||||
u.Name.GivenName,
|
|
||||||
}
|
|
||||||
if u.Name != nil {
|
|
||||||
names = append(names, u.Name.FamilyName)
|
|
||||||
names = append(names, u.Name.GivenName)
|
|
||||||
}
|
|
||||||
for _, val := range u.Emails {
|
|
||||||
names = append(names, val.Value)
|
|
||||||
}
|
|
||||||
stableSortReverseLength(names)
|
|
||||||
|
|
||||||
for _, name := range names {
|
|
||||||
r.Set(name, "$USERNAME")
|
|
||||||
}
|
|
||||||
|
|
||||||
for ind, val := range u.Groups {
|
|
||||||
r.Set(val.Value, fmt.Sprintf("$USER.Groups[%d]", ind))
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Set(u.Id, "$USER.Id")
|
|
||||||
|
|
||||||
for ind, val := range u.Roles {
|
|
||||||
r.Set(val.Value, fmt.Sprintf("$USER.Roles[%d]", ind))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func stableSortReverseLength(strs []string) {
|
|
||||||
slices.SortStableFunc(strs, func(a, b string) int {
|
|
||||||
return len(b) - len(a)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func NormalizeNewlines(input string) string {
|
|
||||||
output := strings.ReplaceAll(input, "\r\n", "\n")
|
|
||||||
return strings.ReplaceAll(output, "\r", "\n")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,212 @@
|
||||||
|
package testdiff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/databricks/cli/internal/testutil"
|
||||||
|
"github.com/databricks/cli/libs/iamutil"
|
||||||
|
"github.com/databricks/databricks-sdk-go"
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var OverwriteMode = os.Getenv("TESTS_OUTPUT") == "OVERWRITE"
|
||||||
|
|
||||||
|
func ReadFile(t testutil.TestingT, ctx context.Context, filename string) string {
|
||||||
|
data, err := os.ReadFile(filename)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// On CI, on Windows \n in the file somehow end up as \r\n
|
||||||
|
return NormalizeNewlines(string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteFile(t testutil.TestingT, filename, data string) {
|
||||||
|
t.Logf("Overwriting %s", filename)
|
||||||
|
err := os.WriteFile(filename, []byte(data), 0o644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertOutput(t testutil.TestingT, ctx context.Context, out, outTitle, expectedPath string) {
|
||||||
|
expected := ReadFile(t, ctx, expectedPath)
|
||||||
|
|
||||||
|
out = ReplaceOutput(t, ctx, out)
|
||||||
|
|
||||||
|
if out != expected {
|
||||||
|
AssertEqualTexts(t, expectedPath, outTitle, expected, out)
|
||||||
|
|
||||||
|
if OverwriteMode {
|
||||||
|
WriteFile(t, expectedPath, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertOutputJQ(t testutil.TestingT, ctx context.Context, out, outTitle, expectedPath string, ignorePaths []string) {
|
||||||
|
expected := ReadFile(t, ctx, expectedPath)
|
||||||
|
|
||||||
|
out = ReplaceOutput(t, ctx, out)
|
||||||
|
|
||||||
|
if out != expected {
|
||||||
|
AssertEqualJQ(t.(*testing.T), expectedPath, outTitle, expected, out, ignorePaths)
|
||||||
|
|
||||||
|
if OverwriteMode {
|
||||||
|
WriteFile(t, expectedPath, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
uuidRegex = regexp.MustCompile(`[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}`)
|
||||||
|
numIdRegex = regexp.MustCompile(`[0-9]{3,}`)
|
||||||
|
privatePathRegex = regexp.MustCompile(`(/tmp|/private)(/.*)/([a-zA-Z0-9]+)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func ReplaceOutput(t testutil.TestingT, ctx context.Context, out string) string {
|
||||||
|
out = NormalizeNewlines(out)
|
||||||
|
replacements := GetReplacementsMap(ctx)
|
||||||
|
if replacements == nil {
|
||||||
|
t.Fatal("WithReplacementsMap was not called")
|
||||||
|
}
|
||||||
|
out = replacements.Replace(out)
|
||||||
|
out = uuidRegex.ReplaceAllString(out, "<UUID>")
|
||||||
|
out = numIdRegex.ReplaceAllString(out, "<NUMID>")
|
||||||
|
out = privatePathRegex.ReplaceAllString(out, "/tmp/.../$3")
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
type key int
|
||||||
|
|
||||||
|
const (
|
||||||
|
replacementsMapKey = key(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Replacement struct {
|
||||||
|
Old string
|
||||||
|
New string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReplacementsContext struct {
|
||||||
|
Repls []Replacement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReplacementsContext) Replace(s string) string {
|
||||||
|
// QQQ Should probably only replace whole words
|
||||||
|
for _, repl := range r.Repls {
|
||||||
|
s = strings.ReplaceAll(s, repl.Old, repl.New)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReplacementsContext) Set(old, new string) {
|
||||||
|
if old == "" || new == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Repls = append(r.Repls, Replacement{Old: old, New: new})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithReplacementsMap(ctx context.Context) (context.Context, *ReplacementsContext) {
|
||||||
|
value := ctx.Value(replacementsMapKey)
|
||||||
|
if value != nil {
|
||||||
|
if existingMap, ok := value.(*ReplacementsContext); ok {
|
||||||
|
return ctx, existingMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newMap := &ReplacementsContext{}
|
||||||
|
ctx = context.WithValue(ctx, replacementsMapKey, newMap)
|
||||||
|
return ctx, newMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetReplacementsMap(ctx context.Context) *ReplacementsContext {
|
||||||
|
value := ctx.Value(replacementsMapKey)
|
||||||
|
if value != nil {
|
||||||
|
if existingMap, ok := value.(*ReplacementsContext); ok {
|
||||||
|
return existingMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareReplacements(t testutil.TestingT, r *ReplacementsContext, w *databricks.WorkspaceClient) {
|
||||||
|
// in some clouds (gcp) w.Config.Host includes "https://" prefix in others it's really just a host (azure)
|
||||||
|
host := strings.TrimPrefix(strings.TrimPrefix(w.Config.Host, "http://"), "https://")
|
||||||
|
r.Set(host, "$DATABRICKS_HOST")
|
||||||
|
r.Set(w.Config.ClusterID, "$DATABRICKS_CLUSTER_ID")
|
||||||
|
r.Set(w.Config.WarehouseID, "$DATABRICKS_WAREHOUSE_ID")
|
||||||
|
r.Set(w.Config.ServerlessComputeID, "$DATABRICKS_SERVERLESS_COMPUTE_ID")
|
||||||
|
r.Set(w.Config.MetadataServiceURL, "$DATABRICKS_METADATA_SERVICE_URL")
|
||||||
|
r.Set(w.Config.AccountID, "$DATABRICKS_ACCOUNT_ID")
|
||||||
|
r.Set(w.Config.Token, "$DATABRICKS_TOKEN")
|
||||||
|
r.Set(w.Config.Username, "$DATABRICKS_USERNAME")
|
||||||
|
r.Set(w.Config.Password, "$DATABRICKS_PASSWORD")
|
||||||
|
r.Set(w.Config.Profile, "$DATABRICKS_CONFIG_PROFILE")
|
||||||
|
r.Set(w.Config.ConfigFile, "$DATABRICKS_CONFIG_FILE")
|
||||||
|
r.Set(w.Config.GoogleServiceAccount, "$DATABRICKS_GOOGLE_SERVICE_ACCOUNT")
|
||||||
|
r.Set(w.Config.GoogleCredentials, "$GOOGLE_CREDENTIALS")
|
||||||
|
r.Set(w.Config.AzureResourceID, "$DATABRICKS_AZURE_RESOURCE_ID")
|
||||||
|
r.Set(w.Config.AzureClientSecret, "$ARM_CLIENT_SECRET")
|
||||||
|
// r.Set(w.Config.AzureClientID, "$ARM_CLIENT_ID")
|
||||||
|
r.Set(w.Config.AzureClientID, "$USERNAME")
|
||||||
|
r.Set(w.Config.AzureTenantID, "$ARM_TENANT_ID")
|
||||||
|
r.Set(w.Config.ActionsIDTokenRequestURL, "$ACTIONS_ID_TOKEN_REQUEST_URL")
|
||||||
|
r.Set(w.Config.ActionsIDTokenRequestToken, "$ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||||
|
r.Set(w.Config.AzureEnvironment, "$ARM_ENVIRONMENT")
|
||||||
|
r.Set(w.Config.ClientID, "$DATABRICKS_CLIENT_ID")
|
||||||
|
r.Set(w.Config.ClientSecret, "$DATABRICKS_CLIENT_SECRET")
|
||||||
|
r.Set(w.Config.DatabricksCliPath, "$DATABRICKS_CLI_PATH")
|
||||||
|
// This is set to words like "path" that happen too frequently
|
||||||
|
// r.Set(w.Config.AuthType, "$DATABRICKS_AUTH_TYPE")
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareReplacementsUser(t testutil.TestingT, r *ReplacementsContext, u iam.User) {
|
||||||
|
// There could be exact matches or overlap between different name fields, so sort them by length
|
||||||
|
// to ensure we match the largest one first and map them all to the same token
|
||||||
|
names := []string{
|
||||||
|
u.DisplayName,
|
||||||
|
u.UserName,
|
||||||
|
iamutil.GetShortUserName(&u),
|
||||||
|
u.Name.FamilyName,
|
||||||
|
u.Name.GivenName,
|
||||||
|
}
|
||||||
|
if u.Name != nil {
|
||||||
|
names = append(names, u.Name.FamilyName)
|
||||||
|
names = append(names, u.Name.GivenName)
|
||||||
|
}
|
||||||
|
for _, val := range u.Emails {
|
||||||
|
names = append(names, val.Value)
|
||||||
|
}
|
||||||
|
stableSortReverseLength(names)
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
r.Set(name, "$USERNAME")
|
||||||
|
}
|
||||||
|
|
||||||
|
for ind, val := range u.Groups {
|
||||||
|
r.Set(val.Value, fmt.Sprintf("$USER.Groups[%d]", ind))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Set(u.Id, "$USER.Id")
|
||||||
|
|
||||||
|
for ind, val := range u.Roles {
|
||||||
|
r.Set(val.Value, fmt.Sprintf("$USER.Roles[%d]", ind))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stableSortReverseLength(strs []string) {
|
||||||
|
slices.SortStableFunc(strs, func(a, b string) int {
|
||||||
|
return len(b) - len(a)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeNewlines(input string) string {
|
||||||
|
output := strings.ReplaceAll(input, "\r\n", "\n")
|
||||||
|
return strings.ReplaceAll(output, "\r", "\n")
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package testcli
|
package testdiff
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
|
@ -0,0 +1,9 @@
|
||||||
|
#!/bin/bash
|
||||||
|
set -euo pipefail
|
||||||
|
# With golangci-lint, if there are any compliation issues, then formatters' autofix won't be applied.
|
||||||
|
# https://github.com/golangci/golangci-lint/issues/5257
|
||||||
|
# However, running goimports first alone will actually fix some of the compilation issues.
|
||||||
|
# Fixing formatting is also reasonable thing to do.
|
||||||
|
# For this reason, this script runs golangci-lint in two stages:
|
||||||
|
golangci-lint run --enable-only="gofmt,gofumpt,goimports" --fix $@
|
||||||
|
exec golangci-lint run --fix $@
|
Loading…
Reference in New Issue