From cb3ad737f10979e0d683d61d6924bdf58596513c Mon Sep 17 00:00:00 2001 From: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com> Date: Thu, 1 Feb 2024 22:16:07 +0530 Subject: [PATCH] Add short_name helper function to bundle init templates (#1167) ## Changes Adds the short_name helper function. short_name is useful when templates do not want to print the full userName (typically email or service principal application-id) of the current user. ## Tests Integration test. Also adds integration tests for other helper functions that interact with the Databricks API. --- .../config/mutator/populate_current_user.go | 13 +--- .../mutator/populate_current_user_test.go | 67 ----------------- internal/init_test.go | 73 +++++++++++++++++++ internal/workspace_test.go | 1 - libs/auth/user.go | 15 ++++ libs/auth/user_test.go | 72 ++++++++++++++++++ libs/template/helpers.go | 10 +++ 7 files changed, 172 insertions(+), 79 deletions(-) create mode 100644 libs/auth/user.go create mode 100644 libs/auth/user_test.go diff --git a/bundle/config/mutator/populate_current_user.go b/bundle/config/mutator/populate_current_user.go index 60587578..a604cb90 100644 --- a/bundle/config/mutator/populate_current_user.go +++ b/bundle/config/mutator/populate_current_user.go @@ -2,12 +2,11 @@ package mutator import ( "context" - "strings" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/tags" - "github.com/databricks/cli/libs/textutil" ) type populateCurrentUser struct{} @@ -33,7 +32,7 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error } b.Config.Workspace.CurrentUser = &config.User{ - ShortName: getShortUserName(me.UserName), + ShortName: auth.GetShortUserName(me.UserName), User: me, } @@ -42,11 +41,3 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error return nil } - -// Get a short-form username, based on the user's primary email address. -// We leave the full range of unicode letters in tact, but remove all "special" characters, -// including dots, which are not supported in e.g. experiment names. -func getShortUserName(emailAddress string) string { - local, _, _ := strings.Cut(emailAddress, "@") - return textutil.NormalizeString(local) -} diff --git a/bundle/config/mutator/populate_current_user_test.go b/bundle/config/mutator/populate_current_user_test.go index bbb65e07..1475055d 100644 --- a/bundle/config/mutator/populate_current_user_test.go +++ b/bundle/config/mutator/populate_current_user_test.go @@ -2,75 +2,8 @@ package mutator import ( "testing" - - "github.com/stretchr/testify/assert" ) func TestPopulateCurrentUser(t *testing.T) { // We need to implement workspace client mocking to implement this test. } - -func TestGetShortUserName(t *testing.T) { - tests := []struct { - name string - email string - expected string - }{ - { - email: "test.user.1234@example.com", - expected: "test_user_1234", - }, - { - email: "tést.üser@example.com", - expected: "tést_üser", - }, - { - email: "test$.user@example.com", - expected: "test__user", - }, - { - email: `jöhn.dœ@domain.com`, // Using non-ASCII characters. - expected: "jöhn_dœ", - }, - { - email: `first+tag@email.com`, // The plus (+) sign is used for "sub-addressing" in some email services. - expected: "first_tag", - }, - { - email: `email@sub.domain.com`, // Using a sub-domain. - expected: "email", - }, - { - email: `"_quoted"@domain.com`, // Quoted strings can be part of the local-part. - expected: "__quoted_", - }, - { - email: `name-o'mally@website.org`, // Single quote in the local-part. - expected: "name_o_mally", - }, - { - email: `user%domain@external.com`, // Percent sign can be used for email routing in legacy systems. - expected: "user_domain", - }, - { - email: `long.name.with.dots@domain.net`, // Multiple dots in the local-part. - expected: "long_name_with_dots", - }, - { - email: `me&you@together.com`, // Using an ampersand (&) in the local-part. - expected: "me_you", - }, - { - email: `user!def!xyz@domain.org`, // The exclamation mark can be valid in some legacy systems. - expected: "user_def_xyz", - }, - { - email: `admin@ιντερνετ.com`, // Domain in non-ASCII characters (IDN or Internationalized Domain Name). - expected: "admin", - }, - } - - for _, tt := range tests { - assert.Equal(t, tt.expected, getShortUserName(tt.email)) - } -} diff --git a/internal/init_test.go b/internal/init_test.go index a2eda983..c4c3d6d8 100644 --- a/internal/init_test.go +++ b/internal/init_test.go @@ -1,9 +1,16 @@ package internal import ( + "context" + "os" + "path/filepath" + "strconv" "testing" + "github.com/databricks/cli/libs/auth" + "github.com/databricks/databricks-sdk-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAccBundleInitErrorOnUnknownFields(t *testing.T) { @@ -13,3 +20,69 @@ func TestAccBundleInitErrorOnUnknownFields(t *testing.T) { _, _, err := RequireErrorRun(t, "bundle", "init", "./testdata/init/field-does-not-exist", "--output-dir", tmpDir) assert.EqualError(t, err, "failed to compute file content for bar.tmpl. variable \"does_not_exist\" not defined") } + +func TestAccBundleInitHelpers(t *testing.T) { + env := GetEnvOrSkipTest(t, "CLOUD_ENV") + t.Log(env) + + w, err := databricks.NewWorkspaceClient(&databricks.Config{}) + require.NoError(t, err) + + me, err := w.CurrentUser.Me(context.Background()) + require.NoError(t, err) + + var smallestNode string + switch env { + case "azure": + smallestNode = "Standard_D3_v2" + case "gcp": + smallestNode = "n1-standard-4" + default: + smallestNode = "i3.xlarge" + } + + tests := []struct { + funcName string + expected string + }{ + { + funcName: "{{short_name}}", + expected: auth.GetShortUserName(me.UserName), + }, + { + funcName: "{{user_name}}", + expected: me.UserName, + }, + { + funcName: "{{workspace_host}}", + expected: w.Config.Host, + }, + { + funcName: "{{is_service_principal}}", + expected: strconv.FormatBool(auth.IsServicePrincipal(me.Id)), + }, + { + funcName: "{{smallest_node_type}}", + expected: smallestNode, + }, + } + + for _, test := range tests { + // Setup template to test the helper function. + tmpDir := t.TempDir() + tmpDir2 := t.TempDir() + + err := os.Mkdir(filepath.Join(tmpDir, "template"), 0755) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "template", "foo.txt.tmpl"), []byte(test.funcName), 0644) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "databricks_template_schema.json"), []byte("{}"), 0644) + require.NoError(t, err) + + // Run bundle init. + RequireSuccessfulRun(t, "bundle", "init", tmpDir, "--output-dir", tmpDir2) + + // Assert that the helper function was correctly computed. + assertLocalFileContents(t, filepath.Join(tmpDir2, "foo.txt"), test.expected) + } +} diff --git a/internal/workspace_test.go b/internal/workspace_test.go index a6e641b6..16467739 100644 --- a/internal/workspace_test.go +++ b/internal/workspace_test.go @@ -81,7 +81,6 @@ func setupWorkspaceImportExportTest(t *testing.T) (context.Context, filer.Filer, return ctx, f, tmpdir } -// TODO: add tests for the progress event output logs: https://github.com/databricks/cli/issues/447 func assertLocalFileContents(t *testing.T, path string, content string) { require.FileExists(t, path) b, err := os.ReadFile(path) diff --git a/libs/auth/user.go b/libs/auth/user.go new file mode 100644 index 00000000..8eaa8763 --- /dev/null +++ b/libs/auth/user.go @@ -0,0 +1,15 @@ +package auth + +import ( + "strings" + + "github.com/databricks/cli/libs/textutil" +) + +// Get a short-form username, based on the user's primary email address. +// We leave the full range of unicode letters in tact, but remove all "special" characters, +// including dots, which are not supported in e.g. experiment names. +func GetShortUserName(emailAddress string) string { + local, _, _ := strings.Cut(emailAddress, "@") + return textutil.NormalizeString(local) +} diff --git a/libs/auth/user_test.go b/libs/auth/user_test.go new file mode 100644 index 00000000..eb579fc9 --- /dev/null +++ b/libs/auth/user_test.go @@ -0,0 +1,72 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetShortUserName(t *testing.T) { + tests := []struct { + name string + email string + expected string + }{ + { + email: "test.user.1234@example.com", + expected: "test_user_1234", + }, + { + email: "tést.üser@example.com", + expected: "tést_üser", + }, + { + email: "test$.user@example.com", + expected: "test__user", + }, + { + email: `jöhn.dœ@domain.com`, // Using non-ASCII characters. + expected: "jöhn_dœ", + }, + { + email: `first+tag@email.com`, // The plus (+) sign is used for "sub-addressing" in some email services. + expected: "first_tag", + }, + { + email: `email@sub.domain.com`, // Using a sub-domain. + expected: "email", + }, + { + email: `"_quoted"@domain.com`, // Quoted strings can be part of the local-part. + expected: "__quoted_", + }, + { + email: `name-o'mally@website.org`, // Single quote in the local-part. + expected: "name_o_mally", + }, + { + email: `user%domain@external.com`, // Percent sign can be used for email routing in legacy systems. + expected: "user_domain", + }, + { + email: `long.name.with.dots@domain.net`, // Multiple dots in the local-part. + expected: "long_name_with_dots", + }, + { + email: `me&you@together.com`, // Using an ampersand (&) in the local-part. + expected: "me_you", + }, + { + email: `user!def!xyz@domain.org`, // The exclamation mark can be valid in some legacy systems. + expected: "user_def_xyz", + }, + { + email: `admin@ιντερνετ.com`, // Domain in non-ASCII characters (IDN or Internationalized Domain Name). + expected: "admin", + }, + } + + for _, tt := range tests { + assert.Equal(t, tt.expected, GetShortUserName(tt.email)) + } +} diff --git a/libs/template/helpers.go b/libs/template/helpers.go index 7f306a3a..537fadb1 100644 --- a/libs/template/helpers.go +++ b/libs/template/helpers.go @@ -98,6 +98,16 @@ func loadHelpers(ctx context.Context) template.FuncMap { } return result, nil }, + "short_name": func() (string, error) { + if cachedUser == nil { + var err error + cachedUser, err = w.CurrentUser.Me(ctx) + if err != nil { + return "", err + } + } + return auth.GetShortUserName(cachedUser.UserName), nil + }, "is_service_principal": func() (bool, error) { if cachedIsServicePrincipal != nil { return *cachedIsServicePrincipal, nil