diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go index d22dc750d..8f96d19e7 100644 --- a/integration/bundle/init_test.go +++ b/integration/bundle/init_test.go @@ -45,8 +45,8 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W - // Configure a telemetry logger in the context. - ctx = telemetry.WithDefaultLogger(ctx) + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() @@ -71,10 +71,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) // Assert the telemetry payload is correctly logged. - logs := telemetry.Introspect(ctx) - require.NoError(t, err) - require.Len(t, len(logs), 1) - event := logs[0].BundleInitEvent + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, telemetry.Introspect(ctx), 1) + event := tlmyEvents[0].BundleInitEvent assert.Equal(t, "mlops-stacks", event.TemplateName) get := func(key string) string { @@ -179,8 +178,8 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) { for _, tc := range tcases { ctx, _ := acc.WorkspaceTest(t) - // Configure a telemetry logger in the context. - ctx = telemetry.WithDefaultLogger(ctx) + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() @@ -199,7 +198,7 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) { // Assert the telemetry payload is correctly logged. logs := telemetry.Introspect(ctx) - require.Len(t, len(logs), 1) + require.Len(t, logs, 1) event := logs[0].BundleInitEvent assert.Equal(t, event.TemplateName, tc.name) @@ -258,17 +257,17 @@ func TestBundleInitTelemetryForCustomTemplates(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644) require.NoError(t, err) - // Configure a telemetry logger in the context. - ctx = telemetry.WithDefaultLogger(ctx) + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) // Run bundle init. testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json")) // Assert the telemetry payload is correctly logged. For custom templates we should // never set template_enum_args. - logs := telemetry.Introspect(ctx) - require.Len(t, len(logs), 1) - event := logs[0].BundleInitEvent + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, len(tlmyEvents), 1) + event := tlmyEvents[0].BundleInitEvent assert.Equal(t, "custom", event.TemplateName) assert.Empty(t, event.TemplateEnumArgs) diff --git a/libs/telemetry/context.go b/libs/telemetry/context.go index ed8f7f710..8236fde04 100644 --- a/libs/telemetry/context.go +++ b/libs/telemetry/context.go @@ -11,15 +11,32 @@ type telemetryLogger int // Key to store the telemetry logger in the context var telemetryLoggerKey telemetryLogger +// TODO: Add tests for these methods. func WithDefaultLogger(ctx context.Context) context.Context { v := ctx.Value(telemetryLoggerKey) - if v != nil { - panic(fmt.Sprintf("telemetry logger already set in the context: %v", v)) + + // If no logger is set in the context, set the default logger. + if v == nil { + nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{logs: []FrontendLog{}}) + return nctx } - return context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{logs: []FrontendLog{}}) + switch v.(type) { + case *defaultLogger: + panic(fmt.Sprintf("default telemetry logger already set in the context: %v", v)) + case *mockLogger: + // Do nothing. Unit and integration tests set the mock logger in the context + // to avoid making actual API calls. Thus WithDefaultLogger should silently + // ignore the mock logger. + default: + panic(fmt.Sprintf("unexpected telemetry logger type: %T", v)) + } + + return ctx } +// WithMockLogger sets a mock telemetry logger in the context. It overrides the +// default logger if it is already set in the context. func WithMockLogger(ctx context.Context) context.Context { v := ctx.Value(telemetryLoggerKey) if v != nil { @@ -35,11 +52,11 @@ func fromContext(ctx context.Context) Logger { panic("telemetry logger not found in the context") } - switch v.(type) { + switch vv := v.(type) { case *defaultLogger: - return v.(*defaultLogger) + return vv case *mockLogger: - return v.(*mockLogger) + return vv default: panic(fmt.Sprintf("unexpected telemetry logger type: %T", v)) }