mirror of https://github.com/databricks/cli.git
add tests for the context methods
This commit is contained in:
parent
4832b545b0
commit
ce0667219a
|
@ -11,25 +11,24 @@ type telemetryLogger int
|
||||||
// Key to store the telemetry logger in the context
|
// Key to store the telemetry logger in the context
|
||||||
var telemetryLoggerKey telemetryLogger
|
var telemetryLoggerKey telemetryLogger
|
||||||
|
|
||||||
// TODO: Add tests for these methods.
|
|
||||||
func WithDefaultLogger(ctx context.Context) context.Context {
|
func WithDefaultLogger(ctx context.Context) context.Context {
|
||||||
v := ctx.Value(telemetryLoggerKey)
|
v := ctx.Value(telemetryLoggerKey)
|
||||||
|
|
||||||
// If no logger is set in the context, set the default logger.
|
// If no logger is set in the context, set the default logger.
|
||||||
if v == nil {
|
if v == nil {
|
||||||
nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{logs: []FrontendLog{}})
|
nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{})
|
||||||
return nctx
|
return nctx
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v.(type) {
|
switch v.(type) {
|
||||||
case *defaultLogger:
|
case *defaultLogger:
|
||||||
panic(fmt.Sprintf("default telemetry logger already set in the context: %v", v))
|
panic(fmt.Errorf("default telemetry logger already set in the context: %T", v))
|
||||||
case *mockLogger:
|
case *mockLogger:
|
||||||
// Do nothing. Unit and integration tests set the mock logger in the context
|
// Do nothing. Unit and integration tests set the mock logger in the context
|
||||||
// to avoid making actual API calls. Thus WithDefaultLogger should silently
|
// to avoid making actual API calls. Thus WithDefaultLogger should silently
|
||||||
// ignore the mock logger.
|
// ignore the mock logger.
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unexpected telemetry logger type: %T", v))
|
panic(fmt.Errorf("unexpected telemetry logger type: %T", v))
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
|
@ -40,7 +39,7 @@ func WithDefaultLogger(ctx context.Context) context.Context {
|
||||||
func WithMockLogger(ctx context.Context) context.Context {
|
func WithMockLogger(ctx context.Context) context.Context {
|
||||||
v := ctx.Value(telemetryLoggerKey)
|
v := ctx.Value(telemetryLoggerKey)
|
||||||
if v != nil {
|
if v != nil {
|
||||||
panic(fmt.Sprintf("telemetry logger already set in the context: %v", v))
|
panic(fmt.Errorf("telemetry logger already set in the context: %T", v))
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.WithValue(ctx, telemetryLoggerKey, &mockLogger{})
|
return context.WithValue(ctx, telemetryLoggerKey, &mockLogger{})
|
||||||
|
@ -49,7 +48,7 @@ func WithMockLogger(ctx context.Context) context.Context {
|
||||||
func fromContext(ctx context.Context) Logger {
|
func fromContext(ctx context.Context) Logger {
|
||||||
v := ctx.Value(telemetryLoggerKey)
|
v := ctx.Value(telemetryLoggerKey)
|
||||||
if v == nil {
|
if v == nil {
|
||||||
panic("telemetry logger not found in the context")
|
panic(fmt.Errorf("telemetry logger not found in the context"))
|
||||||
}
|
}
|
||||||
|
|
||||||
switch vv := v.(type) {
|
switch vv := v.(type) {
|
||||||
|
@ -58,6 +57,6 @@ func fromContext(ctx context.Context) Logger {
|
||||||
case *mockLogger:
|
case *mockLogger:
|
||||||
return vv
|
return vv
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unexpected telemetry logger type: %T", v))
|
panic(fmt.Errorf("unexpected telemetry logger type: %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
package telemetry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithDefaultLogger(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// No default logger set
|
||||||
|
ctx1 := WithDefaultLogger(ctx)
|
||||||
|
assert.Equal(t, &defaultLogger{}, ctx1.Value(telemetryLoggerKey))
|
||||||
|
|
||||||
|
// Default logger already set
|
||||||
|
assert.PanicsWithError(t, "default telemetry logger already set in the context: *telemetry.defaultLogger", func() {
|
||||||
|
WithDefaultLogger(ctx1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mock logger already set
|
||||||
|
ctx2 := WithMockLogger(ctx)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
WithDefaultLogger(ctx2)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Unexpected logger type
|
||||||
|
type foobar struct{}
|
||||||
|
ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{})
|
||||||
|
assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() {
|
||||||
|
WithDefaultLogger(ctx3)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithMockLogger(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// No logger set
|
||||||
|
ctx1 := WithMockLogger(ctx)
|
||||||
|
assert.Equal(t, &mockLogger{}, ctx1.Value(telemetryLoggerKey))
|
||||||
|
|
||||||
|
// Logger already set
|
||||||
|
assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.mockLogger", func() {
|
||||||
|
WithMockLogger(ctx1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Default logger already set
|
||||||
|
ctx2 := WithDefaultLogger(ctx)
|
||||||
|
assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.defaultLogger", func() {
|
||||||
|
WithMockLogger(ctx2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromContext(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// No logger set
|
||||||
|
assert.PanicsWithError(t, "telemetry logger not found in the context", func() {
|
||||||
|
fromContext(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Default logger set
|
||||||
|
ctx1 := WithDefaultLogger(ctx)
|
||||||
|
assert.Equal(t, &defaultLogger{}, fromContext(ctx1))
|
||||||
|
|
||||||
|
// Mock logger set
|
||||||
|
ctx2 := WithMockLogger(ctx)
|
||||||
|
assert.Equal(t, &mockLogger{}, fromContext(ctx2))
|
||||||
|
|
||||||
|
// Unexpected logger type
|
||||||
|
type foobar struct{}
|
||||||
|
ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{})
|
||||||
|
assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() {
|
||||||
|
fromContext(ctx3)
|
||||||
|
})
|
||||||
|
}
|
|
@ -35,6 +35,9 @@ type defaultLogger struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *defaultLogger) Log(event DatabricksCliLog) {
|
func (l *defaultLogger) Log(event DatabricksCliLog) {
|
||||||
|
if l.logs == nil {
|
||||||
|
l.logs = make([]FrontendLog, 0)
|
||||||
|
}
|
||||||
l.logs = append(l.logs, FrontendLog{
|
l.logs = append(l.logs, FrontendLog{
|
||||||
// The telemetry endpoint deduplicates logs based on the FrontendLogEventID.
|
// The telemetry endpoint deduplicates logs based on the FrontendLogEventID.
|
||||||
// This it's important to generate a unique ID for each log event.
|
// This it's important to generate a unique ID for each log event.
|
||||||
|
|
|
@ -16,6 +16,9 @@ type mockLogger struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *mockLogger) Log(event DatabricksCliLog) {
|
func (l *mockLogger) Log(event DatabricksCliLog) {
|
||||||
|
if l.events == nil {
|
||||||
|
l.events = make([]DatabricksCliLog, 0)
|
||||||
|
}
|
||||||
l.events = append(l.events, event)
|
l.events = append(l.events, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue