Add short_name helper function to bundle init templates (#1167)

## Changes
Adds the short_name helper function. short_name is useful when templates
do not want to print the full userName (typically email or service
principal application-id) of the current user.

## Tests
Integration test. Also adds integration tests for other helper functions
that interact with the Databricks API.
This commit is contained in:
shreyas-goenka 2024-02-01 22:16:07 +05:30 committed by GitHub
parent 0b3eeb8e54
commit cb3ad737f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 79 deletions

View File

@ -2,12 +2,11 @@ package mutator
import (
"context"
"strings"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/tags"
"github.com/databricks/cli/libs/textutil"
)
type populateCurrentUser struct{}
@ -33,7 +32,7 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error
}
b.Config.Workspace.CurrentUser = &config.User{
ShortName: getShortUserName(me.UserName),
ShortName: auth.GetShortUserName(me.UserName),
User: me,
}
@ -42,11 +41,3 @@ func (m *populateCurrentUser) Apply(ctx context.Context, b *bundle.Bundle) error
return nil
}
// Get a short-form username, based on the user's primary email address.
// We leave the full range of unicode letters in tact, but remove all "special" characters,
// including dots, which are not supported in e.g. experiment names.
func getShortUserName(emailAddress string) string {
local, _, _ := strings.Cut(emailAddress, "@")
return textutil.NormalizeString(local)
}

View File

@ -2,75 +2,8 @@ package mutator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPopulateCurrentUser(t *testing.T) {
// We need to implement workspace client mocking to implement this test.
}
func TestGetShortUserName(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
email: "test.user.1234@example.com",
expected: "test_user_1234",
},
{
email: "tést.üser@example.com",
expected: "tést_üser",
},
{
email: "test$.user@example.com",
expected: "test__user",
},
{
email: `jöhn.dœ@domain.com`, // Using non-ASCII characters.
expected: "jöhn_dœ",
},
{
email: `first+tag@email.com`, // The plus (+) sign is used for "sub-addressing" in some email services.
expected: "first_tag",
},
{
email: `email@sub.domain.com`, // Using a sub-domain.
expected: "email",
},
{
email: `"_quoted"@domain.com`, // Quoted strings can be part of the local-part.
expected: "__quoted_",
},
{
email: `name-o'mally@website.org`, // Single quote in the local-part.
expected: "name_o_mally",
},
{
email: `user%domain@external.com`, // Percent sign can be used for email routing in legacy systems.
expected: "user_domain",
},
{
email: `long.name.with.dots@domain.net`, // Multiple dots in the local-part.
expected: "long_name_with_dots",
},
{
email: `me&you@together.com`, // Using an ampersand (&) in the local-part.
expected: "me_you",
},
{
email: `user!def!xyz@domain.org`, // The exclamation mark can be valid in some legacy systems.
expected: "user_def_xyz",
},
{
email: `admin@ιντερνετ.com`, // Domain in non-ASCII characters (IDN or Internationalized Domain Name).
expected: "admin",
},
}
for _, tt := range tests {
assert.Equal(t, tt.expected, getShortUserName(tt.email))
}
}

View File

@ -1,9 +1,16 @@
package internal
import (
"context"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/databricks-sdk-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAccBundleInitErrorOnUnknownFields(t *testing.T) {
@ -13,3 +20,69 @@ func TestAccBundleInitErrorOnUnknownFields(t *testing.T) {
_, _, err := RequireErrorRun(t, "bundle", "init", "./testdata/init/field-does-not-exist", "--output-dir", tmpDir)
assert.EqualError(t, err, "failed to compute file content for bar.tmpl. variable \"does_not_exist\" not defined")
}
func TestAccBundleInitHelpers(t *testing.T) {
env := GetEnvOrSkipTest(t, "CLOUD_ENV")
t.Log(env)
w, err := databricks.NewWorkspaceClient(&databricks.Config{})
require.NoError(t, err)
me, err := w.CurrentUser.Me(context.Background())
require.NoError(t, err)
var smallestNode string
switch env {
case "azure":
smallestNode = "Standard_D3_v2"
case "gcp":
smallestNode = "n1-standard-4"
default:
smallestNode = "i3.xlarge"
}
tests := []struct {
funcName string
expected string
}{
{
funcName: "{{short_name}}",
expected: auth.GetShortUserName(me.UserName),
},
{
funcName: "{{user_name}}",
expected: me.UserName,
},
{
funcName: "{{workspace_host}}",
expected: w.Config.Host,
},
{
funcName: "{{is_service_principal}}",
expected: strconv.FormatBool(auth.IsServicePrincipal(me.Id)),
},
{
funcName: "{{smallest_node_type}}",
expected: smallestNode,
},
}
for _, test := range tests {
// Setup template to test the helper function.
tmpDir := t.TempDir()
tmpDir2 := t.TempDir()
err := os.Mkdir(filepath.Join(tmpDir, "template"), 0755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(tmpDir, "template", "foo.txt.tmpl"), []byte(test.funcName), 0644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(tmpDir, "databricks_template_schema.json"), []byte("{}"), 0644)
require.NoError(t, err)
// Run bundle init.
RequireSuccessfulRun(t, "bundle", "init", tmpDir, "--output-dir", tmpDir2)
// Assert that the helper function was correctly computed.
assertLocalFileContents(t, filepath.Join(tmpDir2, "foo.txt"), test.expected)
}
}

View File

@ -81,7 +81,6 @@ func setupWorkspaceImportExportTest(t *testing.T) (context.Context, filer.Filer,
return ctx, f, tmpdir
}
// TODO: add tests for the progress event output logs: https://github.com/databricks/cli/issues/447
func assertLocalFileContents(t *testing.T, path string, content string) {
require.FileExists(t, path)
b, err := os.ReadFile(path)

15
libs/auth/user.go Normal file
View File

@ -0,0 +1,15 @@
package auth
import (
"strings"
"github.com/databricks/cli/libs/textutil"
)
// Get a short-form username, based on the user's primary email address.
// We leave the full range of unicode letters in tact, but remove all "special" characters,
// including dots, which are not supported in e.g. experiment names.
func GetShortUserName(emailAddress string) string {
local, _, _ := strings.Cut(emailAddress, "@")
return textutil.NormalizeString(local)
}

72
libs/auth/user_test.go Normal file
View File

@ -0,0 +1,72 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetShortUserName(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
email: "test.user.1234@example.com",
expected: "test_user_1234",
},
{
email: "tést.üser@example.com",
expected: "tést_üser",
},
{
email: "test$.user@example.com",
expected: "test__user",
},
{
email: `jöhn.dœ@domain.com`, // Using non-ASCII characters.
expected: "jöhn_dœ",
},
{
email: `first+tag@email.com`, // The plus (+) sign is used for "sub-addressing" in some email services.
expected: "first_tag",
},
{
email: `email@sub.domain.com`, // Using a sub-domain.
expected: "email",
},
{
email: `"_quoted"@domain.com`, // Quoted strings can be part of the local-part.
expected: "__quoted_",
},
{
email: `name-o'mally@website.org`, // Single quote in the local-part.
expected: "name_o_mally",
},
{
email: `user%domain@external.com`, // Percent sign can be used for email routing in legacy systems.
expected: "user_domain",
},
{
email: `long.name.with.dots@domain.net`, // Multiple dots in the local-part.
expected: "long_name_with_dots",
},
{
email: `me&you@together.com`, // Using an ampersand (&) in the local-part.
expected: "me_you",
},
{
email: `user!def!xyz@domain.org`, // The exclamation mark can be valid in some legacy systems.
expected: "user_def_xyz",
},
{
email: `admin@ιντερνετ.com`, // Domain in non-ASCII characters (IDN or Internationalized Domain Name).
expected: "admin",
},
}
for _, tt := range tests {
assert.Equal(t, tt.expected, GetShortUserName(tt.email))
}
}

View File

@ -98,6 +98,16 @@ func loadHelpers(ctx context.Context) template.FuncMap {
}
return result, nil
},
"short_name": func() (string, error) {
if cachedUser == nil {
var err error
cachedUser, err = w.CurrentUser.Me(ctx)
if err != nil {
return "", err
}
}
return auth.GetShortUserName(cachedUser.UserName), nil
},
"is_service_principal": func() (bool, error) {
if cachedIsServicePrincipal != nil {
return *cachedIsServicePrincipal, nil