diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go
index 1e698c624..c5074a061 100644
--- a/cmd/bundle/init.go
+++ b/cmd/bundle/init.go
@@ -201,7 +201,7 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
 	cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
 		// Configure the logger to send telemetry to Databricks.
-		ctx := telemetry.ContextWithLogger(cmd.Context())
+		ctx := telemetry.WithDefaultLogger(cmd.Context())
 		cmd.SetContext(ctx)
 
 		return root.MustWorkspaceClient(cmd, args)
diff --git a/integration/bundle/helpers_test.go b/integration/bundle/helpers_test.go
index 36a3231db..bcca56ea9 100644
--- a/integration/bundle/helpers_test.go
+++ b/integration/bundle/helpers_test.go
@@ -29,7 +29,6 @@ const defaultSparkVersion = "13.3.x-snapshot-scala2.12"
 func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any) string {
 	bundleRoot := t.TempDir()
 
-	ctx = telemetry.ContextWithLogger(ctx)
 	return initTestTemplateWithBundleRoot(t, ctx, templateName, config, bundleRoot)
 }
 
@@ -38,10 +37,10 @@ func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, te
 	configFilePath := writeConfigFile(t, config)
 
-	ctx = telemetry.ContextWithLogger(ctx)
 	ctx = root.SetWorkspaceClient(ctx, nil)
 	cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles")
 	ctx = cmdio.InContext(ctx, cmd)
+	ctx = telemetry.WithMockLogger(ctx)
 
 	out, err := filer.NewLocalClient(bundleRoot)
 	require.NoError(t, err)
diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go
index 1bd7e0034..d22dc750d 100644
--- a/integration/bundle/init_test.go
+++ b/integration/bundle/init_test.go
@@ -46,7 +46,7 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
 	w := wt.W
 
 	// Configure a telemetry logger in the context.
-	ctx = telemetry.ContextWithLogger(ctx)
+	ctx = telemetry.WithDefaultLogger(ctx)
 
 	tmpDir1 := t.TempDir()
 	tmpDir2 := t.TempDir()
@@ -71,10 +71,10 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
 	testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
 
 	// Assert the telemetry payload is correctly logged.
-	logs := telemetry.GetLogs(ctx)
+	logs := telemetry.Introspect(ctx)
 	require.NoError(t, err)
 	require.Len(t, len(logs), 1)
-	event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+	event := logs[0].BundleInitEvent
 	assert.Equal(t, "mlops-stacks", event.TemplateName)
 
 	get := func(key string) string {
@@ -180,7 +180,7 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
 	ctx, _ := acc.WorkspaceTest(t)
 
 	// Configure a telemetry logger in the context.
-	ctx = telemetry.ContextWithLogger(ctx)
+	ctx = telemetry.WithDefaultLogger(ctx)
 
 	tmpDir1 := t.TempDir()
 	tmpDir2 := t.TempDir()
@@ -198,10 +198,9 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
 	assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
 
 	// Assert the telemetry payload is correctly logged.
-	logs := telemetry.GetLogs(ctx)
-	require.NoError(t, err)
+	logs := telemetry.Introspect(ctx)
 	require.Len(t, len(logs), 1)
-	event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+	event := logs[0].BundleInitEvent
 	assert.Equal(t, event.TemplateName, tc.name)
 
 	get := func(key string) string {
@@ -260,17 +259,16 @@ func TestBundleInitTelemetryForCustomTemplates(t *testing.T) {
 	require.NoError(t, err)
 
 	// Configure a telemetry logger in the context.
-	ctx = telemetry.ContextWithLogger(ctx)
+	ctx = telemetry.WithDefaultLogger(ctx)
 
 	// Run bundle init.
 	testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json"))
 
 	// Assert the telemetry payload is correctly logged. For custom templates we should
 	// never set template_enum_args.
-	logs := telemetry.GetLogs(ctx)
-	require.NoError(t, err)
+	logs := telemetry.Introspect(ctx)
 	require.Len(t, len(logs), 1)
-	event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
+	event := logs[0].BundleInitEvent
 	assert.Equal(t, "custom", event.TemplateName)
 	assert.Empty(t, event.TemplateEnumArgs)
diff --git a/integration/libs/telemetry/telemetry_test.go b/integration/libs/telemetry/telemetry_test.go
index 8d5948b4b..e6a0edcea 100644
--- a/integration/libs/telemetry/telemetry_test.go
+++ b/integration/libs/telemetry/telemetry_test.go
@@ -57,16 +57,17 @@ func TestTelemetryLogger(t *testing.T) {
 		},
 	}
 
-	assert.Len(t, reflect.TypeOf(telemetry.DatabricksCliLog{}).NumField(), len(events),
+	assert.Equal(t, len(events), reflect.TypeOf(telemetry.DatabricksCliLog{}).NumField(),
 		"Number of events should match the number of fields in DatabricksCliLog. Please add a new event to this test.")
 
 	ctx, w := acc.WorkspaceTest(t)
-	ctx = telemetry.ContextWithLogger(ctx)
+	ctx = telemetry.WithDefaultLogger(ctx)
 
 	// Extend the maximum wait time for the telemetry flush just for this test.
+	oldV := telemetry.MaxAdditionalWaitTime
 	telemetry.MaxAdditionalWaitTime = 1 * time.Hour
 	t.Cleanup(func() {
-		telemetry.MaxAdditionalWaitTime = 2 * time.Second
+		telemetry.MaxAdditionalWaitTime = oldV
 	})
 
 	for _, event := range events {
diff --git a/libs/telemetry/context.go b/libs/telemetry/context.go
index 42c7ef870..ed8f7f710 100644
--- a/libs/telemetry/context.go
+++ b/libs/telemetry/context.go
@@ -2,6 +2,7 @@ package telemetry
 
 import (
 	"context"
+	"fmt"
 )
 
 // Private type to store the telemetry logger in the context
@@ -10,21 +11,36 @@ type telemetryLogger int
 // Key to store the telemetry logger in the context
 var telemetryLoggerKey telemetryLogger
 
-func ContextWithLogger(ctx context.Context) context.Context {
-	_, ok := ctx.Value(telemetryLoggerKey).(*logger)
-	if ok {
-		// If a logger is already configured in the context, do not set a new one.
-		// This is useful for testing.
-		return ctx
+func WithDefaultLogger(ctx context.Context) context.Context {
+	v := ctx.Value(telemetryLoggerKey)
+	if v != nil {
+		panic(fmt.Sprintf("telemetry logger already set in the context: %v", v))
 	}
 
-	return context.WithValue(ctx, telemetryLoggerKey, &logger{logs: []FrontendLog{}})
+	return context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{logs: []FrontendLog{}})
 }
 
-func fromContext(ctx context.Context) *logger {
-	l, ok := ctx.Value(telemetryLoggerKey).(*logger)
-	if !ok {
+func WithMockLogger(ctx context.Context) context.Context {
+	v := ctx.Value(telemetryLoggerKey)
+	if v != nil {
+		panic(fmt.Sprintf("telemetry logger already set in the context: %v", v))
+	}
+
+	return context.WithValue(ctx, telemetryLoggerKey, &mockLogger{})
+}
+
+func fromContext(ctx context.Context) Logger {
+	v := ctx.Value(telemetryLoggerKey)
+	if v == nil {
 		panic("telemetry logger not found in the context")
 	}
-	return l
+
+	switch v.(type) {
+	case *defaultLogger:
+		return v.(*defaultLogger)
+	case *mockLogger:
+		return v.(*mockLogger)
+	default:
+		panic(fmt.Sprintf("unexpected telemetry logger type: %T", v))
+	}
 }
diff --git a/libs/telemetry/logger.go b/libs/telemetry/logger.go
index 9fb907e83..43abf5dd3 100644
--- a/libs/telemetry/logger.go
+++ b/libs/telemetry/logger.go
@@ -17,9 +17,24 @@ type DatabricksApiClient interface {
 		visitors ...func(*http.Request) error) error
 }
 
-func Log(ctx context.Context, event DatabricksCliLog) {
-	l := fromContext(ctx)
+type Logger interface {
+	// Record a telemetry event, to be flushed later.
+	Log(event DatabricksCliLog)
 
+	// Flush all the telemetry events that have been logged so far. We expect
+	// this to be called once per CLI command for the default logger.
+	Flush(ctx context.Context, apiClient DatabricksApiClient)
+
+	// This function is meant to be used only in tests to introspect
+	// the telemetry logs that have been logged so far.
+	Introspect() []DatabricksCliLog
+}
+
+type defaultLogger struct {
+	logs []FrontendLog
+}
+
+func (l *defaultLogger) Log(event DatabricksCliLog) {
 	l.logs = append(l.logs, FrontendLog{
 		// The telemetry endpoint deduplicates logs based on the FrontendLogEventID.
 		// This it's important to generate a unique ID for each log event.
@@ -30,17 +45,6 @@ func Log(ctx context.Context, event DatabricksCliLog) {
 	})
 }
 
-type logger struct {
-	logs []FrontendLog
-}
-
-// This function is meant to be only to be used in tests to introspect the telemetry logs
-// that have been logged so far.
-func GetLogs(ctx context.Context) []FrontendLog {
-	l := fromContext(ctx)
-	return l.logs
-}
-
 // Maximum additional time to wait for the telemetry event to flush. We expect the flush
 // method to be called when the CLI command is about to exist, so this caps the maximum
 // additional time the user will experience because of us logging CLI telemetry.
@@ -50,11 +54,10 @@ var MaxAdditionalWaitTime = 3 * time.Second
 // right about as the CLI command is about to exit. The API endpoint can handle
 // payloads with ~1000 events easily. Thus we log all the events at once instead of
 // batching the logs across multiple API calls.
-func Flush(ctx context.Context, apiClient DatabricksApiClient) {
+func (l *defaultLogger) Flush(ctx context.Context, apiClient DatabricksApiClient) {
 	// Set a maximum time to wait for the telemetry event to flush.
 	ctx, cancel := context.WithTimeout(ctx, MaxAdditionalWaitTime)
 	defer cancel()
 
-	l := fromContext(ctx)
 	if len(l.logs) == 0 {
 		log.Debugf(ctx, "No telemetry events to flush")
@@ -112,3 +115,22 @@ func Flush(ctx context.Context, apiClient DatabricksApiClient) {
 		return
 	}
 }
+
+func (l *defaultLogger) Introspect() []DatabricksCliLog {
+	panic("not implemented")
+}
+
+func Log(ctx context.Context, event DatabricksCliLog) {
+	l := fromContext(ctx)
+	l.Log(event)
+}
+
+func Flush(ctx context.Context, apiClient DatabricksApiClient) {
+	l := fromContext(ctx)
+	l.Flush(ctx, apiClient)
+}
+
+func Introspect(ctx context.Context) []DatabricksCliLog {
+	l := fromContext(ctx)
+	return l.Introspect()
+}
diff --git a/libs/telemetry/logger_test.go b/libs/telemetry/logger_test.go
index 05be23aeb..adbb165a9 100644
--- a/libs/telemetry/logger_test.go
+++ b/libs/telemetry/logger_test.go
@@ -5,7 +5,6 @@ import (
 	"math/rand"
 	"net/http"
 	"testing"
-	"time"
 
 	"github.com/databricks/cli/libs/telemetry/events"
 	"github.com/google/uuid"
@@ -64,7 +63,7 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
 		uuid.SetRand(nil)
 	})
 
-	ctx := ContextWithLogger(context.Background())
+	ctx := WithDefaultLogger(context.Background())
 
 	for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
 		Log(ctx, DatabricksCliLog{
@@ -82,9 +81,10 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
 
 func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
 	// Set the maximum additional wait time to 0 to ensure that the Flush method times out immediately.
+	oldV := MaxAdditionalWaitTime
 	MaxAdditionalWaitTime = 0
 	t.Cleanup(func() {
-		MaxAdditionalWaitTime = 2 * time.Second
+		MaxAdditionalWaitTime = oldV
 	})
 
 	mockClient := &mockDatabricksClient{
@@ -97,7 +97,7 @@ func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
 		uuid.SetRand(nil)
 	})
 
-	ctx := ContextWithLogger(context.Background())
+	ctx := WithDefaultLogger(context.Background())
 
 	for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
 		Log(ctx, DatabricksCliLog{
diff --git a/libs/telemetry/mock_logger.go b/libs/telemetry/mock_logger.go
new file mode 100644
index 000000000..e673b3f89
--- /dev/null
+++ b/libs/telemetry/mock_logger.go
@@ -0,0 +1,28 @@
+package telemetry
+
+import "context"
+
+// TODO CONTINUE:
+// 1. Continue cleaning up the telemetry PR. Clean up the interfaces
+// and add good mock / testing support by storing this in the context.
+//
+// 2. Test that the logging is done correctly and that all components work fine.
+//
+// 3. Ask once more for review. Also announce plans to do this by separately
+// spawning a new process. We can add a new CLI command in the executable to
+// do so.
+type mockLogger struct {
+	events []DatabricksCliLog
+}
+
+func (l *mockLogger) Log(event DatabricksCliLog) {
+	l.events = append(l.events, event)
+}
+
+func (l *mockLogger) Flush(ctx context.Context, apiClient DatabricksApiClient) {
+	// Do nothing.
+}
+
+func (l *mockLogger) Introspect() []DatabricksCliLog {
+	return l.events
+}
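Usage sketch (not part of the diff above): one way a test might exercise the new plumbing, combining WithMockLogger, the package-level Log helper, and Introspect introduced in these changes. The test name and the empty event payload are illustrative assumptions, not code from the PR.

package telemetry_test

import (
	"context"
	"testing"

	"github.com/databricks/cli/libs/telemetry"
	"github.com/stretchr/testify/assert"
)

func TestMockLoggerCapturesEvents(t *testing.T) {
	// Install the mock logger in the context; it records events in memory
	// and its Flush is a no-op, so nothing is sent to the telemetry API.
	ctx := telemetry.WithMockLogger(context.Background())

	// Log an event through the package-level helper, which resolves the
	// logger stored in the context.
	telemetry.Log(ctx, telemetry.DatabricksCliLog{})

	// Introspect returns the recorded events so the test can assert on them.
	assert.Len(t, telemetry.Introspect(ctx), 1)
}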