diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go
index 4b0040728..1881c4ca0 100644
--- a/cmd/bundle/init.go
+++ b/cmd/bundle/init.go
@@ -15,7 +15,9 @@ import (
 	"github.com/databricks/cli/libs/dbr"
 	"github.com/databricks/cli/libs/filer"
 	"github.com/databricks/cli/libs/git"
+	"github.com/databricks/cli/libs/telemetry"
 	"github.com/databricks/cli/libs/template"
+	"github.com/databricks/databricks-sdk-go/client"
 
 	"github.com/spf13/cobra"
 )
@@ -196,7 +198,26 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
 	cmd.Flags().StringVar(&tag, "tag", "", "Git tag to use for template initialization")
 	cmd.Flags().StringVar(&branch, "branch", "", "Git branch to use for template initialization")
 
-	cmd.PreRunE = root.MustWorkspaceClient
+	cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
+		// Configure the logger to send telemetry to Databricks.
+		ctx := telemetry.ContextWithLogger(cmd.Context())
+		cmd.SetContext(ctx)
+
+		return root.MustWorkspaceClient(cmd, args)
+	}
+
+	cmd.PostRun = func(cmd *cobra.Command, args []string) {
+		ctx := cmd.Context()
+		w := root.WorkspaceClient(ctx)
+		apiClient, err := client.New(w.Config)
+		if err != nil {
+			// Uploading telemetry is best effort. Do not error.
+			return
+		}
+
+		telemetry.Flush(cmd.Context(), apiClient)
+	}
+
 	cmd.RunE = func(cmd *cobra.Command, args []string) error {
 		if tag != "" && branch != "" {
 			return errors.New("only one of --tag or --branch can be specified")
diff --git a/cmd/root/root.go b/cmd/root/root.go
index 20079a0bb..3b37d0176 100644
--- a/cmd/root/root.go
+++ b/cmd/root/root.go
@@ -12,8 +12,6 @@ import (
 	"github.com/databricks/cli/libs/cmdio"
 	"github.com/databricks/cli/libs/dbr"
 	"github.com/databricks/cli/libs/log"
-	"github.com/databricks/cli/libs/telemetry"
-	"github.com/databricks/databricks-sdk-go/client"
 
 	"github.com/spf13/cobra"
 )
@@ -54,9 +52,6 @@ func New(ctx context.Context) *cobra.Command {
 			return err
 		}
 
-		// Configure the logger to send telemetry to Databricks.
-		ctx = telemetry.NewContext(ctx)
-
 		logger := log.GetLogger(ctx)
 		logger.Info("start",
 			slog.String("version", build.GetInfo().Version),
@@ -89,18 +84,6 @@ func New(ctx context.Context) *cobra.Command {
 		return nil
 	}
 
-	cmd.PersistentPostRun = func(cmd *cobra.Command, args []string) {
-		ctx := cmd.Context()
-		w := WorkspaceClient(ctx)
-		apiClient, err := client.New(w.Config)
-		if err != nil {
-			// Uploading telemetry is best effort. Do not error.
-			return
-		}
-
-		telemetry.Flush(cmd.Context(), apiClient)
-	}
-
 	cmd.SetFlagErrorFunc(flagErrorFunc)
 	cmd.SetVersionTemplate("Databricks CLI v{{.Version}}\n")
 	return cmd
diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go
index f5c263ca3..f80f6f8f3 100644
--- a/integration/bundle/init_test.go
+++ b/integration/bundle/init_test.go
@@ -15,6 +15,7 @@ import (
 	"github.com/databricks/cli/internal/testcli"
 	"github.com/databricks/cli/internal/testutil"
 	"github.com/databricks/cli/libs/iamutil"
+	"github.com/databricks/cli/libs/telemetry"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -42,6 +43,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
 	ctx, wt := acc.WorkspaceTest(t)
 	w := wt.W
 
+	// Configure a telemetry logger in the context.
+	ctx = telemetry.ContextWithLogger(ctx)
+
 	tmpDir1 := t.TempDir()
 	tmpDir2 := t.TempDir()
 
@@ -64,6 +68,19 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
 	assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
 	testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
 
+	// Assert the telemetry payload is correctly logged.
+	logs, err := telemetry.GetLogs(ctx)
+	require.NoError(t, err)
+	require.Equal(t, 1, len(logs))
+	event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+	assert.Equal(t, event.TemplateName, "mlops-stacks")
+	// Enum values should be present in the telemetry payload.
+	assert.Equal(t, event.TemplateEnumArgs["input_include_models_in_unity_catalog"], "no")
+	assert.Equal(t, event.TemplateEnumArgs["input_cloud"], strings.ToLower(env))
+	// Freeform strings should not be present in the telemetry payload.
+	assert.NotContains(t, event.TemplateEnumArgs, "input_project_name")
+	assert.NotContains(t, event.TemplateEnumArgs, "input_root_dir")
+
 	// Assert that the README.md file was created
 	contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md"))
 	assert.Contains(t, contents, fmt.Sprintf("# %s", projectName))
@@ -99,6 +116,139 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
 	assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName))
 }
 
+func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
+	projectName := testutil.RandomName("name_")
+
+	tcases := []struct {
+		name         string
+		args         map[string]string
+		expectedArgs map[string]string
+	}{
+		{
+			name: "dbt-sql",
+			args: map[string]string{
+				"project_name":     fmt.Sprintf("dbt-sql-%s", projectName),
+				"http_path":        "/sql/1.0/warehouses/id",
+				"default_catalog":  "abcd",
+				"personal_schemas": "yes, use a schema based on the current user name during development",
+			},
+			expectedArgs: map[string]string{
+				"personal_schemas": "yes, use a schema based on the current user name during development",
+			},
+		},
+		{
+			name: "default-python",
+			args: map[string]string{
+				"project_name":     fmt.Sprintf("default_python_%s", projectName),
+				"include_notebook": "yes",
+				"include_dlt":      "yes",
+				"include_python":   "no",
+			},
+			expectedArgs: map[string]string{
+				"include_notebook": "yes",
+				"include_dlt":      "yes",
+				"include_python":   "no",
+			},
+		},
+		{
+			name: "default-sql",
+			args: map[string]string{
+				"project_name":     fmt.Sprintf("sql_project_%s", projectName),
+				"http_path":        "/sql/1.0/warehouses/id",
+				"default_catalog":  "abcd",
+				"personal_schemas": "yes, automatically use a schema based on the current user name during development",
+			},
+			expectedArgs: map[string]string{
+				"personal_schemas": "yes, automatically use a schema based on the current user name during development",
+			},
+		},
+	}
+
+	for _, tc := range tcases {
+		ctx, _ := acc.WorkspaceTest(t)
+
+		// Configure a telemetry logger in the context.
+		ctx = telemetry.ContextWithLogger(ctx)
+
+		tmpDir1 := t.TempDir()
+		tmpDir2 := t.TempDir()
+
+		// Write the template input parameters to a config file.
+		initConfig := tc.args
+		b, err := json.Marshal(initConfig)
+		require.NoError(t, err)
+		err = os.WriteFile(filepath.Join(tmpDir1, "config.json"), b, 0o644)
+		require.NoError(t, err)
+
+		// Run bundle init
+		assert.NoDirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
+		testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tc.name, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
+		assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
+
+		// Assert the telemetry payload is correctly logged.
+		logs, err := telemetry.GetLogs(ctx)
+		require.NoError(t, err)
+		require.Equal(t, 1, len(logs))
+		event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+		assert.Equal(t, event.TemplateName, tc.name)
+		assert.Equal(t, event.TemplateEnumArgs, tc.expectedArgs)
+	}
+}
+
+func TestBundleInitTelemetryForCustomTemplates(t *testing.T) {
+	ctx, _ := acc.WorkspaceTest(t)
+
+	tmpDir1 := t.TempDir()
+	tmpDir2 := t.TempDir()
+	tmpDir3 := t.TempDir()
+
+	err := os.Mkdir(filepath.Join(tmpDir1, "template"), 0o755)
+	require.NoError(t, err)
+	err = os.WriteFile(filepath.Join(tmpDir1, "template", "foo.txt.tmpl"), []byte("doesnotmatter"), 0o644)
+	require.NoError(t, err)
+	err = os.WriteFile(filepath.Join(tmpDir1, "databricks_template_schema.json"), []byte(`
+{
+  "properties": {
+    "a": {
+      "description": "whatever",
+      "type": "string"
+    },
+    "b": {
+      "description": "whatever",
+      "type": "string",
+      "enum": ["yes", "no"]
+    }
+  }
+}
+`), 0o644)
+	require.NoError(t, err)
+
+	// Write the template input parameters to a config file.
+	initConfig := map[string]string{
+		"a": "v1",
+		"b": "yes",
+	}
+	b, err := json.Marshal(initConfig)
+	require.NoError(t, err)
+	err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644)
+	require.NoError(t, err)
+
+	// Configure a telemetry logger in the context.
+	ctx = telemetry.ContextWithLogger(ctx)
+
+	// Run bundle init.
+	testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json"))
+
+	// Assert the telemetry payload is correctly logged. For custom templates we should
+	// only log the template name as "custom" and omit the user-provided arguments.
+	logs, err := telemetry.GetLogs(ctx)
+	require.NoError(t, err)
+	require.Equal(t, 1, len(logs))
+	event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+	assert.Equal(t, event.TemplateName, "custom")
+	assert.Nil(t, event.TemplateEnumArgs)
+}
+
 func TestBundleInitHelpers(t *testing.T) {
 	ctx, wt := acc.WorkspaceTest(t)
 	w := wt.W
diff --git a/integration/libs/telemetry/telemetry_test.go b/integration/libs/telemetry/telemetry_test.go
index ba63aba9c..0c02c15d9 100644
--- a/integration/libs/telemetry/telemetry_test.go
+++ b/integration/libs/telemetry/telemetry_test.go
@@ -32,7 +32,7 @@ func (wrapper *apiClientWrapper) Do(ctx context.Context, method, path string,
 
 func TestTelemetryLogger(t *testing.T) {
 	ctx, w := acc.WorkspaceTest(t)
-	ctx = telemetry.NewContext(ctx)
+	ctx = telemetry.ContextWithLogger(ctx)
 
 	// Extend the maximum wait time for the telemetry flush just for this test.
 	telemetry.MaxAdditionalWaitTime = 1 * time.Hour
diff --git a/libs/telemetry/context.go b/libs/telemetry/context.go
index 5625825d8..9ea913f5a 100644
--- a/libs/telemetry/context.go
+++ b/libs/telemetry/context.go
@@ -10,10 +10,12 @@ type telemetryLogger int
 // Key to store the telemetry logger in the context
 var telemetryLoggerKey telemetryLogger
 
-func NewContext(ctx context.Context) context.Context {
+func ContextWithLogger(ctx context.Context) context.Context {
 	_, ok := ctx.Value(telemetryLoggerKey).(*logger)
 	if ok {
-		panic("telemetry logger already exists in the context")
+		// If a logger is already configured in the context, do not set a new one.
+		// This is useful for testing.
+		return ctx
 	}
 
 	return context.WithValue(ctx, telemetryLoggerKey, &logger{protoLogs: []string{}})
diff --git a/libs/telemetry/logger.go b/libs/telemetry/logger.go
index ce652d675..84e828b3e 100644
--- a/libs/telemetry/logger.go
+++ b/libs/telemetry/logger.go
@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"net/http"
-	"slices"
 	"time"
 
 	"github.com/databricks/cli/libs/log"
@@ -42,6 +41,23 @@ type logger struct {
 	protoLogs []string
 }
 
+// Only to be used in tests to introspect the telemetry logs that are queued
+// to be flushed.
+func GetLogs(ctx context.Context) ([]FrontendLog, error) {
+	l := fromContext(ctx)
+	res := []FrontendLog{}
+
+	for _, log := range l.protoLogs {
+		frontendLog := FrontendLog{}
+		err := json.Unmarshal([]byte(log), &frontendLog)
+		if err != nil {
+			return nil, fmt.Errorf("error unmarshalling the telemetry event: %v", err)
+		}
+		res = append(res, frontendLog)
+	}
+	return res, nil
+}
+
 // Maximum additional time to wait for the telemetry event to flush. We expect the flush
 // method to be called when the CLI command is about to exit, so this caps the maximum
 // additional time the user will experience because of us logging CLI telemetry.
diff --git a/libs/telemetry/logger_test.go b/libs/telemetry/logger_test.go
index 794b79600..d6fc30f24 100644
--- a/libs/telemetry/logger_test.go
+++ b/libs/telemetry/logger_test.go
@@ -65,7 +65,7 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
 		uuid.SetRand(nil)
 	})
 
-	ctx := NewContext(context.Background())
+	ctx := ContextWithLogger(context.Background())
 
 	for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
 		err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
@@ -99,7 +99,7 @@ func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
 		uuid.SetRand(nil)
 	})
 
-	ctx := NewContext(context.Background())
+	ctx := ContextWithLogger(context.Background())
 
 	for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
 		err := Log(ctx, FrontendLogEntry{DatabricksCliLog: DatabricksCliLog{
diff --git a/libs/template/materialize.go b/libs/template/materialize.go
index c0961b486..3ef757735 100644
--- a/libs/template/materialize.go
+++ b/libs/template/materialize.go
@@ -110,7 +110,7 @@ func (t *Template) logTelemetry(ctx context.Context) error {
 	// Only log telemetry input for Databricks owned templates. This is to prevent
 	// accidentally collecting PII from custom user templates.
 	templateEnumArgs := map[string]string{}
-	if !t.IsDatabricksOwned {
+	if t.IsDatabricksOwned {
 		templateEnumArgs = t.config.enumValues()
 	}
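
For reference, a minimal sketch of how another command could adopt the per-command telemetry wiring that this patch introduces for bundle init, now that the flush no longer happens in root's PersistentPostRun. The command name "example" and the function newExampleCommand are hypothetical; only the helpers that already appear in the diff above (telemetry.ContextWithLogger, telemetry.Flush, root.MustWorkspaceClient, root.WorkspaceClient, client.New) are assumed to exist.

package example

import (
	"github.com/databricks/cli/cmd/root"
	"github.com/databricks/cli/libs/telemetry"
	"github.com/databricks/databricks-sdk-go/client"
	"github.com/spf13/cobra"
)

// newExampleCommand is a hypothetical command that attaches a telemetry logger
// before execution and flushes the queued events after the command finishes.
func newExampleCommand() *cobra.Command {
	cmd := &cobra.Command{Use: "example"}

	cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
		// Attach a telemetry logger to the command context before resolving
		// the workspace client.
		cmd.SetContext(telemetry.ContextWithLogger(cmd.Context()))
		return root.MustWorkspaceClient(cmd, args)
	}

	cmd.PostRun = func(cmd *cobra.Command, args []string) {
		w := root.WorkspaceClient(cmd.Context())
		apiClient, err := client.New(w.Config)
		if err != nil {
			// Uploading telemetry is best effort. Do not error.
			return
		}
		telemetry.Flush(cmd.Context(), apiClient)
	}

	return cmd
}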