diff --git a/acceptance/acceptance_test.go b/acceptance/acceptance_test.go index fc486e2cc..e6eb380a4 100644 --- a/acceptance/acceptance_test.go +++ b/acceptance/acceptance_test.go @@ -88,7 +88,7 @@ func TestAccept(t *testing.T) { require.NoError(t, err) require.NotNil(t, user) testdiff.PrepareReplacementsUser(t, &repls, *user) - testdiff.PrepareReplacements(t, &repls, workspaceClient) + testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient) testDirs := getTests(t) require.NotEmpty(t, testDirs) diff --git a/integration/bundle/init_default_python_test.go b/integration/bundle/init_default_python_test.go index c93e6b50b..931660032 100644 --- a/integration/bundle/init_default_python_test.go +++ b/integration/bundle/init_default_python_test.go @@ -58,7 +58,10 @@ func testDefaultPython(t *testing.T, pythonVersion string) { require.NoError(t, err) require.NotNil(t, user) testdiff.PrepareReplacementsUser(t, replacements, *user) - testdiff.PrepareReplacements(t, replacements, wt.W) + testdiff.PrepareReplacementsWorkspaceClient(t, replacements, wt.W) + testdiff.PrepareReplacementsUUID(t, replacements) + testdiff.PrepareReplacementsNumber(t, replacements) + testdiff.PrepareReplacementsTemporaryDirectory(t, replacements) tmpDir := t.TempDir() testutil.Chdir(t, tmpDir) diff --git a/libs/testdiff/context.go b/libs/testdiff/context.go new file mode 100644 index 000000000..7b6f5ff88 --- /dev/null +++ b/libs/testdiff/context.go @@ -0,0 +1,34 @@ +package testdiff + +import ( + "context" +) + +type key int + +const ( + replacementsMapKey = key(1) +) + +func WithReplacementsMap(ctx context.Context) (context.Context, *ReplacementsContext) { + value := ctx.Value(replacementsMapKey) + if value != nil { + if existingMap, ok := value.(*ReplacementsContext); ok { + return ctx, existingMap + } + } + + newMap := &ReplacementsContext{} + ctx = context.WithValue(ctx, replacementsMapKey, newMap) + return ctx, newMap +} + +func GetReplacementsMap(ctx context.Context) *ReplacementsContext { + value := ctx.Value(replacementsMapKey) + if value != nil { + if existingMap, ok := value.(*ReplacementsContext); ok { + return existingMap + } + } + return nil +} diff --git a/libs/testdiff/context_test.go b/libs/testdiff/context_test.go new file mode 100644 index 000000000..5a0191009 --- /dev/null +++ b/libs/testdiff/context_test.go @@ -0,0 +1,30 @@ +package testdiff + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetReplacementsMap_Nil(t *testing.T) { + ctx := context.Background() + repls := GetReplacementsMap(ctx) + assert.Nil(t, repls) +} + +func TestGetReplacementsMap_NotNil(t *testing.T) { + ctx := context.Background() + ctx, _ = WithReplacementsMap(ctx) + repls := GetReplacementsMap(ctx) + assert.NotNil(t, repls) +} + +func TestWithReplacementsMap_UseExisting(t *testing.T) { + ctx := context.Background() + ctx, r1 := WithReplacementsMap(ctx) + ctx, r2 := WithReplacementsMap(ctx) + repls := GetReplacementsMap(ctx) + assert.Equal(t, r1, repls) + assert.Equal(t, r2, repls) +} diff --git a/libs/testdiff/golden.go b/libs/testdiff/golden.go index dd07df408..c1c51b6c5 100644 --- a/libs/testdiff/golden.go +++ b/libs/testdiff/golden.go @@ -2,25 +2,15 @@ package testdiff import ( "context" - "encoding/json" "flag" - "fmt" "os" - "regexp" "strings" "testing" "github.com/databricks/cli/internal/testutil" - "github.com/databricks/cli/libs/iamutil" - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/service/iam" "github.com/stretchr/testify/assert" ) -const ( - testerName = "$USERNAME" -) - var OverwriteMode = false func init() { @@ -75,12 +65,6 @@ func AssertOutputJQ(t testutil.TestingT, ctx context.Context, out, outTitle, exp } } -var ( - uuidRegex = regexp.MustCompile(`[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}`) - numIdRegex = regexp.MustCompile(`[0-9]{3,}`) - privatePathRegex = regexp.MustCompile(`(/tmp|/private)(/.*)/([a-zA-Z0-9]+)`) -) - func ReplaceOutput(t testutil.TestingT, ctx context.Context, out string) string { t.Helper() out = NormalizeNewlines(out) @@ -88,139 +72,7 @@ func ReplaceOutput(t testutil.TestingT, ctx context.Context, out string) string if replacements == nil { t.Fatal("WithReplacementsMap was not called") } - out = replacements.Replace(out) - out = uuidRegex.ReplaceAllString(out, "") - out = numIdRegex.ReplaceAllString(out, "") - out = privatePathRegex.ReplaceAllString(out, "/tmp/.../$3") - - return out -} - -type key int - -const ( - replacementsMapKey = key(1) -) - -type Replacement struct { - Old string - New string -} - -type ReplacementsContext struct { - Repls []Replacement -} - -func (r *ReplacementsContext) Replace(s string) string { - // QQQ Should probably only replace whole words - for _, repl := range r.Repls { - s = strings.ReplaceAll(s, repl.Old, repl.New) - } - return s -} - -func (r *ReplacementsContext) Set(old, new string) { - if old == "" || new == "" { - return - } - - // Always include both verbatim and json version of replacement. - // This helps when the string in question contains \ or other chars that need to be quoted. - // In that case we cannot rely that json(old) == '"{old}"' and need to add it explicitly. - - encodedNew, err := json.Marshal(new) - if err == nil { - encodedOld, err := json.Marshal(old) - if err == nil { - r.Repls = append(r.Repls, Replacement{Old: string(encodedOld), New: string(encodedNew)}) - } - } - - r.Repls = append(r.Repls, Replacement{Old: old, New: new}) -} - -func WithReplacementsMap(ctx context.Context) (context.Context, *ReplacementsContext) { - value := ctx.Value(replacementsMapKey) - if value != nil { - if existingMap, ok := value.(*ReplacementsContext); ok { - return ctx, existingMap - } - } - - newMap := &ReplacementsContext{} - ctx = context.WithValue(ctx, replacementsMapKey, newMap) - return ctx, newMap -} - -func GetReplacementsMap(ctx context.Context) *ReplacementsContext { - value := ctx.Value(replacementsMapKey) - if value != nil { - if existingMap, ok := value.(*ReplacementsContext); ok { - return existingMap - } - } - return nil -} - -func PrepareReplacements(t testutil.TestingT, r *ReplacementsContext, w *databricks.WorkspaceClient) { - t.Helper() - // in some clouds (gcp) w.Config.Host includes "https://" prefix in others it's really just a host (azure) - host := strings.TrimPrefix(strings.TrimPrefix(w.Config.Host, "http://"), "https://") - r.Set(host, "$DATABRICKS_HOST") - r.Set(w.Config.ClusterID, "$DATABRICKS_CLUSTER_ID") - r.Set(w.Config.WarehouseID, "$DATABRICKS_WAREHOUSE_ID") - r.Set(w.Config.ServerlessComputeID, "$DATABRICKS_SERVERLESS_COMPUTE_ID") - r.Set(w.Config.MetadataServiceURL, "$DATABRICKS_METADATA_SERVICE_URL") - r.Set(w.Config.AccountID, "$DATABRICKS_ACCOUNT_ID") - r.Set(w.Config.Token, "$DATABRICKS_TOKEN") - r.Set(w.Config.Username, "$DATABRICKS_USERNAME") - r.Set(w.Config.Password, "$DATABRICKS_PASSWORD") - r.Set(w.Config.Profile, "$DATABRICKS_CONFIG_PROFILE") - r.Set(w.Config.ConfigFile, "$DATABRICKS_CONFIG_FILE") - r.Set(w.Config.GoogleServiceAccount, "$DATABRICKS_GOOGLE_SERVICE_ACCOUNT") - r.Set(w.Config.GoogleCredentials, "$GOOGLE_CREDENTIALS") - r.Set(w.Config.AzureResourceID, "$DATABRICKS_AZURE_RESOURCE_ID") - r.Set(w.Config.AzureClientSecret, "$ARM_CLIENT_SECRET") - // r.Set(w.Config.AzureClientID, "$ARM_CLIENT_ID") - r.Set(w.Config.AzureClientID, testerName) - r.Set(w.Config.AzureTenantID, "$ARM_TENANT_ID") - r.Set(w.Config.ActionsIDTokenRequestURL, "$ACTIONS_ID_TOKEN_REQUEST_URL") - r.Set(w.Config.ActionsIDTokenRequestToken, "$ACTIONS_ID_TOKEN_REQUEST_TOKEN") - r.Set(w.Config.AzureEnvironment, "$ARM_ENVIRONMENT") - r.Set(w.Config.ClientID, "$DATABRICKS_CLIENT_ID") - r.Set(w.Config.ClientSecret, "$DATABRICKS_CLIENT_SECRET") - r.Set(w.Config.DatabricksCliPath, "$DATABRICKS_CLI_PATH") - // This is set to words like "path" that happen too frequently - // r.Set(w.Config.AuthType, "$DATABRICKS_AUTH_TYPE") -} - -func PrepareReplacementsUser(t testutil.TestingT, r *ReplacementsContext, u iam.User) { - t.Helper() - // There could be exact matches or overlap between different name fields, so sort them by length - // to ensure we match the largest one first and map them all to the same token - - r.Set(u.UserName, testerName) - r.Set(u.DisplayName, testerName) - if u.Name != nil { - r.Set(u.Name.FamilyName, testerName) - r.Set(u.Name.GivenName, testerName) - } - - for _, val := range u.Emails { - r.Set(val.Value, testerName) - } - - r.Set(iamutil.GetShortUserName(&u), testerName) - - for ind, val := range u.Groups { - r.Set(val.Value, fmt.Sprintf("$USER.Groups[%d]", ind)) - } - - r.Set(u.Id, "$USER.Id") - - for ind, val := range u.Roles { - r.Set(val.Value, fmt.Sprintf("$USER.Roles[%d]", ind)) - } + return replacements.Replace(out) } func NormalizeNewlines(input string) string { diff --git a/libs/testdiff/replacement.go b/libs/testdiff/replacement.go new file mode 100644 index 000000000..fee6d80fd --- /dev/null +++ b/libs/testdiff/replacement.go @@ -0,0 +1,153 @@ +package testdiff + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/iamutil" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/iam" +) + +const ( + testerName = "$USERNAME" +) + +var ( + uuidRegex = regexp.MustCompile(`[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}`) + numIdRegex = regexp.MustCompile(`[0-9]{3,}`) + privatePathRegex = regexp.MustCompile(`(/tmp|/private)(/.*)/([a-zA-Z0-9]+)`) +) + +type Replacement struct { + Old *regexp.Regexp + New string +} + +type ReplacementsContext struct { + Repls []Replacement +} + +func (r *ReplacementsContext) Replace(s string) string { + // QQQ Should probably only replace whole words + for _, repl := range r.Repls { + s = repl.Old.ReplaceAllString(s, repl.New) + } + return s +} + +func (r *ReplacementsContext) append(pattern *regexp.Regexp, replacement string) { + r.Repls = append(r.Repls, Replacement{ + Old: pattern, + New: replacement, + }) +} + +func (r *ReplacementsContext) appendLiteral(old, new string) { + r.append( + // Transform the input strings such that they can be used as literal strings in regular expressions. + regexp.MustCompile(regexp.QuoteMeta(old)), + // Transform the replacement string such that `$` is interpreted as a literal dollar sign. + // For more information about how the replacement string is used, see [regexp.Regexp.Expand]. + strings.ReplaceAll(new, `$`, `$$`), + ) +} + +func (r *ReplacementsContext) Set(old, new string) { + if old == "" || new == "" { + return + } + + // Always include both verbatim and json version of replacement. + // This helps when the string in question contains \ or other chars that need to be quoted. + // In that case we cannot rely that json(old) == '"{old}"' and need to add it explicitly. + + encodedNew, err := json.Marshal(new) + if err == nil { + encodedOld, err := json.Marshal(old) + if err == nil { + r.appendLiteral(string(encodedOld), string(encodedNew)) + } + } + + r.appendLiteral(old, new) +} + +func PrepareReplacementsWorkspaceClient(t testutil.TestingT, r *ReplacementsContext, w *databricks.WorkspaceClient) { + t.Helper() + // in some clouds (gcp) w.Config.Host includes "https://" prefix in others it's really just a host (azure) + host := strings.TrimPrefix(strings.TrimPrefix(w.Config.Host, "http://"), "https://") + r.Set(host, "$DATABRICKS_HOST") + r.Set(w.Config.ClusterID, "$DATABRICKS_CLUSTER_ID") + r.Set(w.Config.WarehouseID, "$DATABRICKS_WAREHOUSE_ID") + r.Set(w.Config.ServerlessComputeID, "$DATABRICKS_SERVERLESS_COMPUTE_ID") + r.Set(w.Config.MetadataServiceURL, "$DATABRICKS_METADATA_SERVICE_URL") + r.Set(w.Config.AccountID, "$DATABRICKS_ACCOUNT_ID") + r.Set(w.Config.Token, "$DATABRICKS_TOKEN") + r.Set(w.Config.Username, "$DATABRICKS_USERNAME") + r.Set(w.Config.Password, "$DATABRICKS_PASSWORD") + r.Set(w.Config.Profile, "$DATABRICKS_CONFIG_PROFILE") + r.Set(w.Config.ConfigFile, "$DATABRICKS_CONFIG_FILE") + r.Set(w.Config.GoogleServiceAccount, "$DATABRICKS_GOOGLE_SERVICE_ACCOUNT") + r.Set(w.Config.GoogleCredentials, "$GOOGLE_CREDENTIALS") + r.Set(w.Config.AzureResourceID, "$DATABRICKS_AZURE_RESOURCE_ID") + r.Set(w.Config.AzureClientSecret, "$ARM_CLIENT_SECRET") + // r.Set(w.Config.AzureClientID, "$ARM_CLIENT_ID") + r.Set(w.Config.AzureClientID, testerName) + r.Set(w.Config.AzureTenantID, "$ARM_TENANT_ID") + r.Set(w.Config.ActionsIDTokenRequestURL, "$ACTIONS_ID_TOKEN_REQUEST_URL") + r.Set(w.Config.ActionsIDTokenRequestToken, "$ACTIONS_ID_TOKEN_REQUEST_TOKEN") + r.Set(w.Config.AzureEnvironment, "$ARM_ENVIRONMENT") + r.Set(w.Config.ClientID, "$DATABRICKS_CLIENT_ID") + r.Set(w.Config.ClientSecret, "$DATABRICKS_CLIENT_SECRET") + r.Set(w.Config.DatabricksCliPath, "$DATABRICKS_CLI_PATH") + // This is set to words like "path" that happen too frequently + // r.Set(w.Config.AuthType, "$DATABRICKS_AUTH_TYPE") +} + +func PrepareReplacementsUser(t testutil.TestingT, r *ReplacementsContext, u iam.User) { + t.Helper() + // There could be exact matches or overlap between different name fields, so sort them by length + // to ensure we match the largest one first and map them all to the same token + + r.Set(u.UserName, testerName) + r.Set(u.DisplayName, testerName) + if u.Name != nil { + r.Set(u.Name.FamilyName, testerName) + r.Set(u.Name.GivenName, testerName) + } + + for _, val := range u.Emails { + r.Set(val.Value, testerName) + } + + r.Set(iamutil.GetShortUserName(&u), testerName) + + for ind, val := range u.Groups { + r.Set(val.Value, fmt.Sprintf("$USER.Groups[%d]", ind)) + } + + r.Set(u.Id, "$USER.Id") + + for ind, val := range u.Roles { + r.Set(val.Value, fmt.Sprintf("$USER.Roles[%d]", ind)) + } +} + +func PrepareReplacementsUUID(t testutil.TestingT, r *ReplacementsContext) { + t.Helper() + r.append(uuidRegex, "") +} + +func PrepareReplacementsNumber(t testutil.TestingT, r *ReplacementsContext) { + t.Helper() + r.append(numIdRegex, "") +} + +func PrepareReplacementsTemporaryDirectory(t testutil.TestingT, r *ReplacementsContext) { + t.Helper() + r.append(privatePathRegex, "/tmp/.../$3") +} diff --git a/libs/testdiff/replacement_test.go b/libs/testdiff/replacement_test.go new file mode 100644 index 000000000..de247c03e --- /dev/null +++ b/libs/testdiff/replacement_test.go @@ -0,0 +1,46 @@ +package testdiff + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReplacement_Literal(t *testing.T) { + var repls ReplacementsContext + + repls.Set(`foobar`, `[replacement]`) + assert.Equal(t, `[replacement]`, repls.Replace(`foobar`)) +} + +func TestReplacement_Encoded(t *testing.T) { + var repls ReplacementsContext + + repls.Set(`foo"bar`, `[replacement]`) + assert.Equal(t, `"[replacement]"`, repls.Replace(`"foo\"bar"`)) +} + +func TestReplacement_UUID(t *testing.T) { + var repls ReplacementsContext + + PrepareReplacementsUUID(t, &repls) + + assert.Equal(t, "", repls.Replace("123e4567-e89b-12d3-a456-426614174000")) +} + +func TestReplacement_Number(t *testing.T) { + var repls ReplacementsContext + + PrepareReplacementsNumber(t, &repls) + + assert.Equal(t, "12", repls.Replace("12")) + assert.Equal(t, "", repls.Replace("123")) +} + +func TestReplacement_TemporaryDirectory(t *testing.T) { + var repls ReplacementsContext + + PrepareReplacementsTemporaryDirectory(t, &repls) + + assert.Equal(t, "/tmp/.../tail", repls.Replace("/tmp/foo/bar/qux/tail")) +}