From ce0667219a958381a038b370d01149b0d68a93f5 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 12:08:09 +0530 Subject: [PATCH] add tests for the context methods --- libs/telemetry/context.go | 13 +++--- libs/telemetry/context_test.go | 77 ++++++++++++++++++++++++++++++++++ libs/telemetry/logger.go | 3 ++ libs/telemetry/mock_logger.go | 3 ++ 4 files changed, 89 insertions(+), 7 deletions(-) create mode 100644 libs/telemetry/context_test.go diff --git a/libs/telemetry/context.go b/libs/telemetry/context.go index 8236fde04..c0aea80bc 100644 --- a/libs/telemetry/context.go +++ b/libs/telemetry/context.go @@ -11,25 +11,24 @@ type telemetryLogger int // Key to store the telemetry logger in the context var telemetryLoggerKey telemetryLogger -// TODO: Add tests for these methods. func WithDefaultLogger(ctx context.Context) context.Context { v := ctx.Value(telemetryLoggerKey) // If no logger is set in the context, set the default logger. if v == nil { - nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{logs: []FrontendLog{}}) + nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{}) return nctx } switch v.(type) { case *defaultLogger: - panic(fmt.Sprintf("default telemetry logger already set in the context: %v", v)) + panic(fmt.Errorf("default telemetry logger already set in the context: %T", v)) case *mockLogger: // Do nothing. Unit and integration tests set the mock logger in the context // to avoid making actual API calls. Thus WithDefaultLogger should silently // ignore the mock logger. default: - panic(fmt.Sprintf("unexpected telemetry logger type: %T", v)) + panic(fmt.Errorf("unexpected telemetry logger type: %T", v)) } return ctx @@ -40,7 +39,7 @@ func WithDefaultLogger(ctx context.Context) context.Context { func WithMockLogger(ctx context.Context) context.Context { v := ctx.Value(telemetryLoggerKey) if v != nil { - panic(fmt.Sprintf("telemetry logger already set in the context: %v", v)) + panic(fmt.Errorf("telemetry logger already set in the context: %T", v)) } return context.WithValue(ctx, telemetryLoggerKey, &mockLogger{}) @@ -49,7 +48,7 @@ func WithMockLogger(ctx context.Context) context.Context { func fromContext(ctx context.Context) Logger { v := ctx.Value(telemetryLoggerKey) if v == nil { - panic("telemetry logger not found in the context") + panic(fmt.Errorf("telemetry logger not found in the context")) } switch vv := v.(type) { @@ -58,6 +57,6 @@ func fromContext(ctx context.Context) Logger { case *mockLogger: return vv default: - panic(fmt.Sprintf("unexpected telemetry logger type: %T", v)) + panic(fmt.Errorf("unexpected telemetry logger type: %T", v)) } } diff --git a/libs/telemetry/context_test.go b/libs/telemetry/context_test.go new file mode 100644 index 000000000..ddcdb83de --- /dev/null +++ b/libs/telemetry/context_test.go @@ -0,0 +1,77 @@ +package telemetry + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithDefaultLogger(t *testing.T) { + ctx := context.Background() + + // No default logger set + ctx1 := WithDefaultLogger(ctx) + assert.Equal(t, &defaultLogger{}, ctx1.Value(telemetryLoggerKey)) + + // Default logger already set + assert.PanicsWithError(t, "default telemetry logger already set in the context: *telemetry.defaultLogger", func() { + WithDefaultLogger(ctx1) + }) + + // Mock logger already set + ctx2 := WithMockLogger(ctx) + assert.NotPanics(t, func() { + WithDefaultLogger(ctx2) + }) + + // Unexpected logger type + type foobar struct{} + ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{}) + assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() { + WithDefaultLogger(ctx3) + }) +} + +func TestWithMockLogger(t *testing.T) { + ctx := context.Background() + + // No logger set + ctx1 := WithMockLogger(ctx) + assert.Equal(t, &mockLogger{}, ctx1.Value(telemetryLoggerKey)) + + // Logger already set + assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.mockLogger", func() { + WithMockLogger(ctx1) + }) + + // Default logger already set + ctx2 := WithDefaultLogger(ctx) + assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.defaultLogger", func() { + WithMockLogger(ctx2) + }) +} + +func TestFromContext(t *testing.T) { + ctx := context.Background() + + // No logger set + assert.PanicsWithError(t, "telemetry logger not found in the context", func() { + fromContext(ctx) + }) + + // Default logger set + ctx1 := WithDefaultLogger(ctx) + assert.Equal(t, &defaultLogger{}, fromContext(ctx1)) + + // Mock logger set + ctx2 := WithMockLogger(ctx) + assert.Equal(t, &mockLogger{}, fromContext(ctx2)) + + // Unexpected logger type + type foobar struct{} + ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{}) + assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() { + fromContext(ctx3) + }) +} diff --git a/libs/telemetry/logger.go b/libs/telemetry/logger.go index 43abf5dd3..bb60696a7 100644 --- a/libs/telemetry/logger.go +++ b/libs/telemetry/logger.go @@ -35,6 +35,9 @@ type defaultLogger struct { } func (l *defaultLogger) Log(event DatabricksCliLog) { + if l.logs == nil { + l.logs = make([]FrontendLog, 0) + } l.logs = append(l.logs, FrontendLog{ // The telemetry endpoint deduplicates logs based on the FrontendLogEventID. // This it's important to generate a unique ID for each log event. diff --git a/libs/telemetry/mock_logger.go b/libs/telemetry/mock_logger.go index e673b3f89..de15dd3d4 100644 --- a/libs/telemetry/mock_logger.go +++ b/libs/telemetry/mock_logger.go @@ -16,6 +16,9 @@ type mockLogger struct { } func (l *mockLogger) Log(event DatabricksCliLog) { + if l.events == nil { + l.events = make([]DatabricksCliLog, 0) + } l.events = append(l.events, event) }