Compare commits


3 Commits

Author          SHA1        Message                                    Date
Shreyas Goenka  ce0667219a  add tests for the context methods          2025-01-03 12:08:09 +05:30
Shreyas Goenka  4832b545b0  use mock logger in integration tests       2025-01-03 11:47:45 +05:30
Shreyas Goenka  2357954885  Add type for mock logger and some cleanup  2025-01-03 11:31:22 +05:30
9 changed files with 217 additions and 55 deletions

View File

@@ -201,7 +201,7 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// Configure the logger to send telemetry to Databricks.
ctx := telemetry.ContextWithLogger(cmd.Context())
ctx := telemetry.WithDefaultLogger(cmd.Context())
cmd.SetContext(ctx)
return root.MustWorkspaceClient(cmd, args)
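
This hunk only shows the PreRunE side of the wiring. For orientation, a minimal sketch of how a cobra command could both attach the default logger and flush it on exit is below; the PostRun hook, the command name, and the apiClient parameter are illustrative assumptions, not part of this diff.

package main

import (
	"github.com/databricks/cli/libs/telemetry"
	"github.com/spf13/cobra"
)

// newInitCommand sketches the assumed lifecycle: attach the default telemetry
// logger before the command runs and flush the accumulated events afterwards.
func newInitCommand(apiClient telemetry.DatabricksApiClient) *cobra.Command {
	cmd := &cobra.Command{Use: "init"}
	cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
		// Attach the default logger so downstream code can call telemetry.Log.
		cmd.SetContext(telemetry.WithDefaultLogger(cmd.Context()))
		return nil
	}
	cmd.PostRun = func(cmd *cobra.Command, args []string) {
		// Flush events right before the command exits (assumed call site;
		// the diff does not show where Flush is invoked).
		telemetry.Flush(cmd.Context(), apiClient)
	}
	return cmd
}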

View File

@@ -29,7 +29,6 @@ const defaultSparkVersion = "13.3.x-snapshot-scala2.12"
func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any) string {
bundleRoot := t.TempDir()
ctx = telemetry.ContextWithLogger(ctx)
return initTestTemplateWithBundleRoot(t, ctx, templateName, config, bundleRoot)
}
@@ -38,10 +37,10 @@ func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, te
configFilePath := writeConfigFile(t, config)
ctx = telemetry.ContextWithLogger(ctx)
ctx = root.SetWorkspaceClient(ctx, nil)
cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles")
ctx = cmdio.InContext(ctx, cmd)
ctx = telemetry.WithMockLogger(ctx)
out, err := filer.NewLocalClient(bundleRoot)
require.NoError(t, err)

View File

@@ -45,8 +45,8 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
ctx, wt := acc.WorkspaceTest(t)
w := wt.W
// Configure a telemetry logger in the context.
ctx = telemetry.ContextWithLogger(ctx)
// Use mock logger to introspect the telemetry payload.
ctx = telemetry.WithMockLogger(ctx)
tmpDir1 := t.TempDir()
tmpDir2 := t.TempDir()
@@ -71,10 +71,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) {
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json"))
// Assert the telemetry payload is correctly logged.
logs := telemetry.GetLogs(ctx)
require.NoError(t, err)
require.Len(t, len(logs), 1)
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
tlmyEvents := telemetry.Introspect(ctx)
require.Len(t, tlmyEvents, 1)
event := tlmyEvents[0].BundleInitEvent
assert.Equal(t, "mlops-stacks", event.TemplateName)
get := func(key string) string {
@@ -179,8 +178,8 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
for _, tc := range tcases {
ctx, _ := acc.WorkspaceTest(t)
// Configure a telemetry logger in the context.
ctx = telemetry.ContextWithLogger(ctx)
// Use mock logger to introspect the telemetry payload.
ctx = telemetry.WithMockLogger(ctx)
tmpDir1 := t.TempDir()
tmpDir2 := t.TempDir()
@@ -198,10 +197,9 @@ func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) {
assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"]))
// Assert the telemetry payload is correctly logged.
logs := telemetry.GetLogs(ctx)
require.NoError(t, err)
require.Len(t, len(logs), 1)
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
logs := telemetry.Introspect(ctx)
require.Len(t, logs, 1)
event := logs[0].BundleInitEvent
assert.Equal(t, event.TemplateName, tc.name)
get := func(key string) string {
@@ -259,18 +257,17 @@ func TestBundleInitTelemetryForCustomTemplates(t *testing.T) {
err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644)
require.NoError(t, err)
// Configure a telemetry logger in the context.
ctx = telemetry.ContextWithLogger(ctx)
// Use mock logger to introspect the telemetry payload.
ctx = telemetry.WithMockLogger(ctx)
// Run bundle init.
testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json"))
// Assert the telemetry payload is correctly logged. For custom templates we should
// never set template_enum_args.
logs := telemetry.GetLogs(ctx)
require.NoError(t, err)
require.Len(t, len(logs), 1)
event := logs[0].Entry.DatabricksCliLog.BundleInitEvent
tlmyEvents := telemetry.Introspect(ctx)
require.Len(t, tlmyEvents, 1)
event := tlmyEvents[0].BundleInitEvent
assert.Equal(t, "custom", event.TemplateName)
assert.Empty(t, event.TemplateEnumArgs)

View File

@@ -57,16 +57,17 @@ func TestTelemetryLogger(t *testing.T) {
},
}
assert.Len(t, reflect.TypeOf(telemetry.DatabricksCliLog{}).NumField(), len(events),
assert.Equal(t, len(events), reflect.TypeOf(telemetry.DatabricksCliLog{}).NumField(),
"Number of events should match the number of fields in DatabricksCliLog. Please add a new event to this test.")
ctx, w := acc.WorkspaceTest(t)
ctx = telemetry.ContextWithLogger(ctx)
ctx = telemetry.WithDefaultLogger(ctx)
// Extend the maximum wait time for the telemetry flush just for this test.
oldV := telemetry.MaxAdditionalWaitTime
telemetry.MaxAdditionalWaitTime = 1 * time.Hour
t.Cleanup(func() {
telemetry.MaxAdditionalWaitTime = 2 * time.Second
telemetry.MaxAdditionalWaitTime = oldV
})
for _, event := range events {
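
The hard-coded reset value is replaced by capturing oldV and restoring it in t.Cleanup. If more tests end up tweaking this knob, the same idiom could be factored into a tiny helper; a sketch only, the helper name setMaxWait is hypothetical, and it assumes the testing, time, and telemetry imports already present in this file.

// setMaxWait overrides the package-level flush timeout for a single test and
// restores the previous value when the test finishes.
func setMaxWait(t *testing.T, d time.Duration) {
	oldV := telemetry.MaxAdditionalWaitTime
	telemetry.MaxAdditionalWaitTime = d
	t.Cleanup(func() { telemetry.MaxAdditionalWaitTime = oldV })
}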

View File

@@ -2,6 +2,7 @@ package telemetry
import (
"context"
"fmt"
)
// Private type to store the telemetry logger in the context
@@ -10,21 +11,52 @@ type telemetryLogger int
// Key to store the telemetry logger in the context
var telemetryLoggerKey telemetryLogger
func ContextWithLogger(ctx context.Context) context.Context {
_, ok := ctx.Value(telemetryLoggerKey).(*logger)
if ok {
// If a logger is already configured in the context, do not set a new one.
// This is useful for testing.
return ctx
func WithDefaultLogger(ctx context.Context) context.Context {
v := ctx.Value(telemetryLoggerKey)
// If no logger is set in the context, set the default logger.
if v == nil {
nctx := context.WithValue(ctx, telemetryLoggerKey, &defaultLogger{})
return nctx
}
return context.WithValue(ctx, telemetryLoggerKey, &logger{logs: []FrontendLog{}})
switch v.(type) {
case *defaultLogger:
panic(fmt.Errorf("default telemetry logger already set in the context: %T", v))
case *mockLogger:
// Do nothing. Unit and integration tests set the mock logger in the context
// to avoid making actual API calls. Thus WithDefaultLogger should silently
// ignore the mock logger.
default:
panic(fmt.Errorf("unexpected telemetry logger type: %T", v))
}
return ctx
}
func fromContext(ctx context.Context) *logger {
l, ok := ctx.Value(telemetryLoggerKey).(*logger)
if !ok {
panic("telemetry logger not found in the context")
// WithMockLogger sets a mock telemetry logger in the context. It panics if a
// logger has already been set in the context.
func WithMockLogger(ctx context.Context) context.Context {
v := ctx.Value(telemetryLoggerKey)
if v != nil {
panic(fmt.Errorf("telemetry logger already set in the context: %T", v))
}
return context.WithValue(ctx, telemetryLoggerKey, &mockLogger{})
}
func fromContext(ctx context.Context) Logger {
v := ctx.Value(telemetryLoggerKey)
if v == nil {
panic(fmt.Errorf("telemetry logger not found in the context"))
}
switch vv := v.(type) {
case *defaultLogger:
return vv
case *mockLogger:
return vv
default:
panic(fmt.Errorf("unexpected telemetry logger type: %T", v))
}
return l
}
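
The net effect of the two constructors is an ordering contract: a test installs the mock first, a later WithDefaultLogger call from production code is then a silent no-op, and double registration of a logger panics instead of being swallowed. A small sketch of that contract from a caller's point of view (the function is illustrative, not part of the CLI):

package main

import (
	"context"

	"github.com/databricks/cli/libs/telemetry"
)

// loggerWiring demonstrates the intended composition order of the constructors.
func loggerWiring() {
	// Production path: a command attaches the default logger exactly once.
	ctx := telemetry.WithDefaultLogger(context.Background())
	_ = ctx

	// Test path: the test installs the mock before the command runs; the
	// command's later WithDefaultLogger call keeps the mock in place.
	tctx := telemetry.WithMockLogger(context.Background())
	tctx = telemetry.WithDefaultLogger(tctx) // no-op, the mock stays

	// Misuse panics rather than silently replacing an existing logger:
	// telemetry.WithDefaultLogger(ctx) // panics: default logger already set
	// telemetry.WithMockLogger(tctx)   // panics: a logger is already set
	_ = tctx
}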

View File

@@ -0,0 +1,77 @@
package telemetry
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithDefaultLogger(t *testing.T) {
ctx := context.Background()
// No default logger set
ctx1 := WithDefaultLogger(ctx)
assert.Equal(t, &defaultLogger{}, ctx1.Value(telemetryLoggerKey))
// Default logger already set
assert.PanicsWithError(t, "default telemetry logger already set in the context: *telemetry.defaultLogger", func() {
WithDefaultLogger(ctx1)
})
// Mock logger already set
ctx2 := WithMockLogger(ctx)
assert.NotPanics(t, func() {
WithDefaultLogger(ctx2)
})
// Unexpected logger type
type foobar struct{}
ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{})
assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() {
WithDefaultLogger(ctx3)
})
}
func TestWithMockLogger(t *testing.T) {
ctx := context.Background()
// No logger set
ctx1 := WithMockLogger(ctx)
assert.Equal(t, &mockLogger{}, ctx1.Value(telemetryLoggerKey))
// Logger already set
assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.mockLogger", func() {
WithMockLogger(ctx1)
})
// Default logger already set
ctx2 := WithDefaultLogger(ctx)
assert.PanicsWithError(t, "telemetry logger already set in the context: *telemetry.defaultLogger", func() {
WithMockLogger(ctx2)
})
}
func TestFromContext(t *testing.T) {
ctx := context.Background()
// No logger set
assert.PanicsWithError(t, "telemetry logger not found in the context", func() {
fromContext(ctx)
})
// Default logger set
ctx1 := WithDefaultLogger(ctx)
assert.Equal(t, &defaultLogger{}, fromContext(ctx1))
// Mock logger set
ctx2 := WithMockLogger(ctx)
assert.Equal(t, &mockLogger{}, fromContext(ctx2))
// Unexpected logger type
type foobar struct{}
ctx3 := context.WithValue(ctx, telemetryLoggerKey, &foobar{})
assert.PanicsWithError(t, "unexpected telemetry logger type: *telemetry.foobar", func() {
fromContext(ctx3)
})
}

View File

@@ -17,9 +17,27 @@ type DatabricksApiClient interface {
visitors ...func(*http.Request) error) error
}
func Log(ctx context.Context, event DatabricksCliLog) {
l := fromContext(ctx)
type Logger interface {
// Record a telemetry event, to be flushed later.
Log(event DatabricksCliLog)
// Flush all the telemetry events that have been logged so far. We expect
// this to be called once per CLI command for the default logger.
Flush(ctx context.Context, apiClient DatabricksApiClient)
// This function is meant to be used only in tests to introspect
// the telemetry logs that have been logged so far.
Introspect() []DatabricksCliLog
}
type defaultLogger struct {
logs []FrontendLog
}
func (l *defaultLogger) Log(event DatabricksCliLog) {
if l.logs == nil {
l.logs = make([]FrontendLog, 0)
}
l.logs = append(l.logs, FrontendLog{
// The telemetry endpoint deduplicates logs based on the FrontendLogEventID.
// Thus it's important to generate a unique ID for each log event.
@@ -30,17 +48,6 @@ func Log(ctx context.Context, event DatabricksCliLog) {
})
}
type logger struct {
logs []FrontendLog
}
// This function is meant to be only to be used in tests to introspect the telemetry logs
// that have been logged so far.
func GetLogs(ctx context.Context) []FrontendLog {
l := fromContext(ctx)
return l.logs
}
// Maximum additional time to wait for the telemetry event to flush. We expect the flush
// method to be called when the CLI command is about to exit, so this caps the maximum
// additional time the user will experience because of us logging CLI telemetry.
@@ -50,11 +57,10 @@ var MaxAdditionalWaitTime = 3 * time.Second
// right about as the CLI command is about to exit. The API endpoint can handle
// payloads with ~1000 events easily. Thus we log all the events at once instead of
// batching the logs across multiple API calls.
func Flush(ctx context.Context, apiClient DatabricksApiClient) {
func (l *defaultLogger) Flush(ctx context.Context, apiClient DatabricksApiClient) {
// Set a maximum time to wait for the telemetry event to flush.
ctx, cancel := context.WithTimeout(ctx, MaxAdditionalWaitTime)
defer cancel()
l := fromContext(ctx)
if len(l.logs) == 0 {
log.Debugf(ctx, "No telemetry events to flush")
@@ -112,3 +118,22 @@ func Flush(ctx context.Context, apiClient DatabricksApiClient) {
return
}
}
func (l *defaultLogger) Introspect() []DatabricksCliLog {
panic("not implemented")
}
func Log(ctx context.Context, event DatabricksCliLog) {
l := fromContext(ctx)
l.Log(event)
}
func Flush(ctx context.Context, apiClient DatabricksApiClient) {
l := fromContext(ctx)
l.Flush(ctx, apiClient)
}
func Introspect(ctx context.Context) []DatabricksCliLog {
l := fromContext(ctx)
return l.Introspect()
}
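
With Log, Flush, and Introspect now dispatching to whichever Logger sits in the context, the two flows look as follows. This is a sketch only: the apiClient parameter stands in for whatever DatabricksApiClient implementation the CLI constructs, and the empty DatabricksCliLog literals are placeholders for real events.

package main

import (
	"context"

	"github.com/databricks/cli/libs/telemetry"
)

// telemetryFlows contrasts the production and test paths behind the
// package-level helpers.
func telemetryFlows(apiClient telemetry.DatabricksApiClient) []telemetry.DatabricksCliLog {
	// Production flow: events accumulate in the default logger and are sent
	// in a single API call as the command is about to exit.
	ctx := telemetry.WithDefaultLogger(context.Background())
	telemetry.Log(ctx, telemetry.DatabricksCliLog{})
	telemetry.Flush(ctx, apiClient)

	// Test flow: events accumulate in the mock logger; Flush is a no-op and
	// Introspect exposes what was logged.
	tctx := telemetry.WithMockLogger(context.Background())
	telemetry.Log(tctx, telemetry.DatabricksCliLog{})
	return telemetry.Introspect(tctx)
}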

View File

@@ -5,7 +5,6 @@ import (
"math/rand"
"net/http"
"testing"
"time"
"github.com/databricks/cli/libs/telemetry/events"
"github.com/google/uuid"
@@ -64,7 +63,7 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
uuid.SetRand(nil)
})
ctx := ContextWithLogger(context.Background())
ctx := WithDefaultLogger(context.Background())
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
Log(ctx, DatabricksCliLog{
@@ -82,9 +81,10 @@ func TestTelemetryLoggerFlushesEvents(t *testing.T) {
func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
// Set the maximum additional wait time to 0 to ensure that the Flush method times out immediately.
oldV := MaxAdditionalWaitTime
MaxAdditionalWaitTime = 0
t.Cleanup(func() {
MaxAdditionalWaitTime = 2 * time.Second
MaxAdditionalWaitTime = oldV
})
mockClient := &mockDatabricksClient{
@@ -97,7 +97,7 @@ func TestTelemetryLoggerFlushExitsOnTimeout(t *testing.T) {
uuid.SetRand(nil)
})
ctx := ContextWithLogger(context.Background())
ctx := WithDefaultLogger(context.Background())
for _, v := range []events.DummyCliEnum{events.DummyCliEnumValue1, events.DummyCliEnumValue2, events.DummyCliEnumValue2, events.DummyCliEnumValue3} {
Log(ctx, DatabricksCliLog{

View File

@@ -0,0 +1,31 @@
package telemetry
import "context"
// TODO CONTINUE:
// 1. Continue cleaning up the telemetry PR. Clean up the interfaces
// and add good mock / testing support by storing this in the context.
//
// 2. Test that the logging is being done correctly. All components work fine.
//
// 3. Ask once more for review. Also announce plans to do this by separately
// spawning a new process. We can add a new CLI command in the executable to
// do so.
type mockLogger struct {
events []DatabricksCliLog
}
func (l *mockLogger) Log(event DatabricksCliLog) {
if l.events == nil {
l.events = make([]DatabricksCliLog, 0)
}
l.events = append(l.events, event)
}
func (l *mockLogger) Flush(ctx context.Context, apiClient DatabricksApiClient) {
// Do nothing
}
func (l *mockLogger) Introspect() []DatabricksCliLog {
return l.events
}
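
As a usage note, the mock keeps everything in memory, so a unit test can exercise the full Log path without any network access. A minimal sketch, written as if it lived next to mock.go in the telemetry package; the test name is illustrative and it assumes the context, testing, and testify/assert imports.

func TestMockLoggerCapturesEvents(t *testing.T) {
	ctx := WithMockLogger(context.Background())

	// Both events go through fromContext and land in the in-memory slice.
	Log(ctx, DatabricksCliLog{})
	Log(ctx, DatabricksCliLog{})

	// Flush is a no-op for the mock, so nothing leaves the process.
	Flush(ctx, nil)

	assert.Len(t, Introspect(ctx), 2)
}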