mirror of https://github.com/databricks/cli.git
add integration tests and bug fixes
This commit is contained in:
parent
ecb977f4ed
commit
44d43fccb7
|
@ -15,7 +15,9 @@ import (
|
||||||
"github.com/databricks/cli/libs/dbr"
|
"github.com/databricks/cli/libs/dbr"
|
||||||
"github.com/databricks/cli/libs/filer"
|
"github.com/databricks/cli/libs/filer"
|
||||||
"github.com/databricks/cli/libs/git"
|
"github.com/databricks/cli/libs/git"
|
||||||
|
"github.com/databricks/cli/libs/telemetry"
|
||||||
"github.com/databricks/cli/libs/template"
|
"github.com/databricks/cli/libs/template"
|
||||||
|
"github.com/databricks/databricks-sdk-go/client"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -196,7 +198,26 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
|
||||||
cmd.Flags().StringVar(&branch, "tag", "", "Git tag to use for template initialization")
|
cmd.Flags().StringVar(&branch, "tag", "", "Git tag to use for template initialization")
|
||||||
cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization")
|
cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization")
|
||||||
|
|
||||||
cmd.PreRunE = root.MustWorkspaceClient
|
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Configure the logger to send telemetry to Databricks.
|
||||||
|
ctx := telemetry.ContextWithLogger(cmd.Context())
|
||||||
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
|
return root.MustWorkspaceClient(cmd, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PostRun = func(cmd *cobra.Command, args []string) {
|
||||||
|
ctx := cmd.Context()
|
||||||
|
w := root.WorkspaceClient(ctx)
|
||||||
|
apiClient, err := client.New(w.Config)
|
||||||
|
if err != nil {
|
||||||
|
// Uploading telemetry is best effort. Do not error.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
telemetry.Flush(cmd.Context(), apiClient)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
cmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||||
if tag != "" && branch != "" {
|
if tag != "" && branch != "" {
|
||||||
return errors.New("only one of --tag or --branch can be specified")
|
return errors.New("only one of --tag or --branch can be specified")
|
||||||
|
|
|
@ -12,8 +12,6 @@ import (
|
||||||
"github.com/databricks/cli/libs/cmdio"
|
"github.com/databricks/cli/libs/cmdio"
|
||||||
"github.com/databricks/cli/libs/dbr"
|
"github.com/databricks/cli/libs/dbr"
|
||||||
"github.com/databricks/cli/libs/log"
|
"github.com/databricks/cli/libs/log"
|
||||||
"github.com/databricks/cli/libs/telemetry"
|
|
||||||
"github.com/databricks/databricks-sdk-go/client"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -54,9 +52,6 @@ func New(ctx context.Context) *cobra.Command {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure the logger to send telemetry to Databricks.
|
|
||||||
ctx = telemetry.NewContext(ctx)
|
|
||||||
|
|
||||||
logger := log.GetLogger(ctx)
|
logger := log.GetLogger(ctx)
|
||||||
logger.Info("start",
|
logger.Info("start",
|
||||||
slog.String("version", build.GetInfo().Version),
|
slog.String("version", build.GetInfo().Version),
|
||||||
|
@ -89,18 +84,6 @@ func New(ctx context.Context) *cobra.Command {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.PersistentPostRun = func(cmd *cobra.Command, args []string) {
|
|
||||||
ctx := cmd.Context()
|
|
||||||
w := WorkspaceClient(ctx)
|
|
||||||
apiClient, err := client.New(w.Config)
|
|
||||||
if err != nil {
|
|
||||||
// Uploading telemetry is best effort. Do not error.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
telemetry.Flush(cmd.Context(), apiClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.SetFlagErrorFunc(flagErrorFunc)
|
cmd.SetFlagErrorFunc(flagErrorFunc)
|
||||||
cmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n")
|
cmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n")
|
||||||
return cmd
|
return cmd
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/databricks/cli/internal/testcli"
|
"github.com/databricks/cli/internal/testcli"
|
||||||
"github.com/databricks/cli/internal/testutil"
|
"github.com/databricks/cli/internal/testutil"
|
||||||
"github.com/databricks/cli/libs/iamutil"
|
"github.com/databricks/cli/libs/iamutil"
|
||||||
|
"github.com/databricks/cli/libs/telemetry"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -42,6 +43,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
|
||||||
ctx, wt := acc.WorkspaceTest(t)
|
ctx, wt := acc.WorkspaceTest(t)
|
||||||
w := wt.W
|
w := wt.W
|
||||||
|
|
||||||
|
// Configure a telemetry logger in the context.
|
||||||
|
ctx = telemetry.ContextWithLogger(ctx)
|
||||||
|
|
||||||
tmpDir1 := t.TempDir()
|
tmpDir1 := t.TempDir()
|
||||||
tmpDir2 := t.TempDir()
|
tmpDir2 := t.TempDir()
|
||||||
|
|
||||||
|
@ -64,6 +68,19 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
|
||||||
assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
|
assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
|
||||||
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
|
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
|
||||||
|
|
||||||
|
// Assert the telemetry payload is correctly logged.
|
||||||
|
logs, err := telemetry.GetLogs(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, len(logs))
|
||||||
|
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
|
||||||
|
assert.Equal(t, event.TemplateName, "mlops-stacks")
|
||||||
|
// Enum values should be present in the telemetry payload.
|
||||||
|
assert.Equal(t, event.TemplateEnumArgs["input_include_models_in_unity_catalog"], "no")
|
||||||
|
assert.Equal(t, event.TemplateEnumArgs["input_cloud"], strings.ToLower(env))
|
||||||
|
// Freeform strings should not be present in the telemetry payload.
|
||||||
|
assert.NotContains(t, event.TemplateEnumArgs, "input_project_name")
|
||||||
|
assert.NotContains(t, event.TemplateEnumArgs, "input_root_dir")
|
||||||
|
|
||||||
// Assert that the README.md file was created
|
// Assert that the README.md file was created
|
||||||
contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
|
contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
|
||||||
assert.Contains(t, contents, fmt.Sprintf("# %s", projectName))
|
assert.Contains(t, contents, fmt.Sprintf("# %s", projectName))
|
||||||
|
@ -99,6 +116,139 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
|
||||||
assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName))
|
assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
|
||||||
|
projectName := testutil.RandomName("name_")
|
||||||
|
|
||||||
|
tcases := []struct {
|
||||||
|
name string
|
||||||
|
args map[string]string
|
||||||
|
expectedArgs map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "dbt-sql",
|
||||||
|
args: map[string]string{
|
||||||
|
"project_name": fmt.Sprintf("dbt-sql-%s", projectName),
|
||||||
|
"http_path": "/sql/1.0/warehouses/id",
|
||||||
|
"default_catalog": "abcd",
|
||||||
|
"personal_schemas": "yes, use a schema based on the current user name during development",
|
||||||
|
},
|
||||||
|
expectedArgs: map[string]string{
|
||||||
|
"personal_schemas": "yes, use a schema based on the current user name during development",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default-python",
|
||||||
|
args: map[string]string{
|
||||||
|
"project_name": fmt.Sprintf("default_python_%s", projectName),
|
||||||
|
"include_notebook": "yes",
|
||||||
|
"include_dlt": "yes",
|
||||||
|
"include_python": "no",
|
||||||
|
},
|
||||||
|
expectedArgs: map[string]string{
|
||||||
|
"include_notebook": "yes",
|
||||||
|
"include_dlt": "yes",
|
||||||
|
"include_python": "no",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default-sql",
|
||||||
|
args: map[string]string{
|
||||||
|
"project_name": fmt.Sprintf("sql_project_%s", projectName),
|
||||||
|
"http_path": "/sql/1.0/warehouses/id",
|
||||||
|
"default_catalog": "abcd",
|
||||||
|
"personal_schemas": "yes, automatically use a schema based on the current user name during development",
|
||||||
|
},
|
||||||
|
expectedArgs: map[string]string{
|
||||||
|
"personal_schemas": "yes, automatically use a schema based on the current user name during development",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcases {
|
||||||
|
ctx, _ := acc.WorkspaceTest(t)
|
||||||
|
|
||||||
|
// Configure a telemetry logger in the context.
|
||||||
|
ctx = telemetry.ContextWithLogger(ctx)
|
||||||
|
|
||||||
|
tmpDir1 := t.TempDir()
|
||||||
|
tmpDir2 := t.TempDir()
|
||||||
|
|
||||||
|
// Create a config file with the project name and root dir
|
||||||
|
initConfig := tc.args
|
||||||
|
b, err := json.Marshal(initConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = os.WriteFile(filepath.Join(tmpDir1, "config.json"), b, 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Run bundle init
|
||||||
|
assert.NoDirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
|
||||||
|
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tc.name, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
|
||||||
|
assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
|
||||||
|
|
||||||
|
// Assert the telemetry payload is correctly logged.
|
||||||
|
logs, err := telemetry.GetLogs(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, len(logs))
|
||||||
|
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
|
||||||
|
assert.Equal(t, event.TemplateName, tc.name)
|
||||||
|
assert.Equal(t, event.TemplateEnumArgs, tc.expectedArgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBundleInitTelemetryForCustomTemplates(t *testing.T) {
|
||||||
|
ctx, _ := acc.WorkspaceTest(t)
|
||||||
|
|
||||||
|
tmpDir1 := t.TempDir()
|
||||||
|
tmpDir2 := t.TempDir()
|
||||||
|
tmpDir3 := t.TempDir()
|
||||||
|
|
||||||
|
err := os.Mkdir(filepath.Join(tmpDir1, "template"), 0o755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = os.WriteFile(filepath.Join(tmpDir1, "template", "foo.txt.tmpl"), []byte("doesnotmatter"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = os.WriteFile(filepath.Join(tmpDir1, "databricks_template_schema.json"), []byte(`
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"a": {
|
||||||
|
"description": "whatever",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"b": {
|
||||||
|
"description": "whatever",
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["yes", "no"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a config file with the project name and root dir
|
||||||
|
initConfig := map[string]string{
|
||||||
|
"a": "v1",
|
||||||
|
"b": "yes",
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(initConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure a telemetry logger in the context.
|
||||||
|
ctx = telemetry.ContextWithLogger(ctx)
|
||||||
|
|
||||||
|
// Run bundle init.
|
||||||
|
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json"))
|
||||||
|
|
||||||
|
// Assert the telemetry payload is correctly logged. For custom templates we should
|
||||||
|
//
|
||||||
|
logs, err := telemetry.GetLogs(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, len(logs))
|
||||||
|
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
|
||||||
|
assert.Equal(t, event.TemplateName, "custom")
|
||||||
|
assert.Nil(t, event.TemplateEnumArgs)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBundleInitHelpers(t *testing.T) {
|
func TestBundleInitHelpers(t *testing.T) {
|
||||||
ctx, wt := acc.WorkspaceTest(t)
|
ctx, wt := acc.WorkspaceTest(t)
|
||||||
w := wt.W
|
w := wt.W
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (wrapper *apiClientWrapper) Do(ctx context.Context, method, path string,
|
||||||
|
|
||||||
func TestTelemetryLogger(t *testing.T) {
|
func TestTelemetryLogger(t *testing.T) {
|
||||||
ctx, w := acc.WorkspaceTest(t)
|
ctx, w := acc.WorkspaceTest(t)
|
||||||
ctx = telemetry.NewContext(ctx)
|
ctx = telemetry.ContextWithLogger(ctx)
|
||||||
|
|
||||||
// Extend the maximum wait time for the telemetry flush just for this test.
|
// Extend the maximum wait time for the telemetry flush just for this test.
|
||||||
telemetry.MaxAdditionalWaitTime = 1 * time.Hour
|
telemetry.MaxAdditionalWaitTime = 1 * time.Hour
|
||||||
|
|
|
@ -10,10 +10,12 @@ type telemetryLogger int
|
||||||
// Key to store the telemetry logger in the context
|
// Key to store the telemetry logger in the context
|
||||||
var telemetryLoggerKey telemetryLogger
|
var telemetryLoggerKey telemetryLogger
|
||||||
|
|
||||||
func NewContext(ctx context.Context) context.Context {
|
func ContextWithLogger(ctx context.Context) context.Context {
|
||||||
_, ok := ctx.Value(telemetryLoggerKey).(*logger)
|
_, ok := ctx.Value(telemetryLoggerKey).(*logger)
|
||||||
if ok {
|
if ok {
|
||||||
panic("telemetry logger already exists in the context")
|
// If a logger is already configured in the context, do not set a new one.
|
||||||
|
// This is useful for testing.
|
||||||
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.WithValue(ctx, telemetryLoggerKey, &logger{protoLogs: []string{}})
|
return context.WithValue(ctx, telemetryLoggerKey, &logger{protoLogs: []string{}})
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/databricks/cli/libs/log"
|
"github.com/databricks/cli/libs/log"
|
||||||
|
@ -42,6 +41,23 @@ type logger struct {
|
||||||
protoLogs []string
|
protoLogs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only to be used in tests to introspect the telemetry logs that are queued
|
||||||
|
// to be flushed.
|
||||||
|
func GetLogs(ctx context.Context) ([]FrontendLog, error) {
|
||||||
|
l := fromContext(ctx)
|
||||||
|
res := []FrontendLog{}
|
||||||
|
|
||||||
|
for _, log := range l.protoLogs {
|
||||||
|
frontendLog := FrontendLog{}
|
||||||
|
err := json.Unmarshal([]byte(log), &frontendLog)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshalling the telemetry event: %v", err)
|
||||||
|
}
|
||||||
|
res = append(res, frontendLog)
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Maximum additional time to wait for the telemetry event to flush. We expect the flush
|
// Maximum additional time to wait for the telemetry event to flush. We expect the flush
|
||||||
// method to be called when the CLI command is about to exist, so this caps the maximum
|
// method to be called when the CLI command is about to exist, so this caps the maximum
|
||||||
// additional time the user will experience because of us logging CLI telemetry.
|
// additional time the user will experience because of us logging CLI telemetry.
|
||||||
|
|
|
@ -65,7 +65,7 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
|
||||||
uuid.SetRand(nil)
|
uuid.SetRand(nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx := NewContext(context.Background())
|
ctx := ContextWithLogger(context.Background())
|
||||||
|
|
||||||
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
|
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
|
||||||
err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
|
err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
|
||||||
|
@ -99,7 +99,7 @@ func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
|
||||||
uuid.SetRand(nil)
|
uuid.SetRand(nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx := NewContext(context.Background())
|
ctx := ContextWithLogger(context.Background())
|
||||||
|
|
||||||
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
|
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
|
||||||
err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
|
err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
|
||||||
|
|
|
@ -110,7 +110,7 @@ func (t *Template) logTelemetry(ctx context.Context) error {
|
||||||
// Only log telemetry input for Databricks owned templates. This is to prevent
|
// Only log telemetry input for Databricks owned templates. This is to prevent
|
||||||
// accidentally collecting PUII from custom user templates.
|
// accidentally collecting PUII from custom user templates.
|
||||||
templateEnumArgs := map[string]string{}
|
templateEnumArgs := map[string]string{}
|
||||||
if !t.IsDatabricksOwned {
|
if t.IsDatabricksOwned {
|
||||||
templateEnumArgs = t.config.enumValues()
|
templateEnumArgs = t.config.enumValues()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue