From 46dd80d277fe6f727244f7adcd52a8ef38ef382e Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Wed, 27 Nov 2024 23:29:04 +0100 Subject: [PATCH] better unit tests --- libs/telemetry/logger.go | 4 +- libs/telemetry/logger_test.go | 133 ++++++++++++++++++++++++++-------- 2 files changed, 105 insertions(+), 32 deletions(-) diff --git a/libs/telemetry/logger.go b/libs/telemetry/logger.go index d56cc65d9..a605b4d33 100644 --- a/libs/telemetry/logger.go +++ b/libs/telemetry/logger.go @@ -125,7 +125,7 @@ func (l *logger) TrackEvent(event FrontendLogEntry) { // Maximum additional time to wait for the telemetry event to flush. We expect the flush // method to be called when the CLI command is about to exist, so this time would // be purely additive to the end user's experience. -const MaxAdditionalWaitTime = 1 * time.Second +var MaxAdditionalWaitTime = 1 * time.Second // TODO: Do not close the connection in-case of error. Allows for faster retry. // TODO: Talk about why we make only one API call at the end. It's because the @@ -152,7 +152,7 @@ func (l *logger) Flush() { } // All logs were successfully sent. No need to retry. - if len(l.protoLogs) <= int(resp.NumProtoSuccess) && len(resp.Errors) > 0 { + if len(l.protoLogs) <= int(resp.NumProtoSuccess) && len(resp.Errors) == 0 { return } diff --git a/libs/telemetry/logger_test.go b/libs/telemetry/logger_test.go index 2f3c96a1e..d13e25c01 100644 --- a/libs/telemetry/logger_test.go +++ b/libs/telemetry/logger_test.go @@ -3,35 +3,67 @@ package telemetry import ( "context" "fmt" + "io" "net/http" "testing" + "time" "github.com/stretchr/testify/assert" ) type mockDatabricksClient struct { numCalls int + + t *testing.T } -// TODO: Assert on the request body provided to this method. func (m *mockDatabricksClient) Do(ctx context.Context, method, path string, headers map[string]string, request, response any, visitors ...func(*http.Request) error) error { - // For the first two calls, we want to return an error to simulate a server - // timeout. For the third call, we want to return a successful response. m.numCalls++ + + assertRequestPayload := func() { + expectedProtoLogs := []string{ + "{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE1\"}}}", + "{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE2\"}}}", + "{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE2\"}}}", + "{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE3\"}}}", + } + + // Assert payload matches the expected payload. + assert.Equal(m.t, expectedProtoLogs, request.(RequestBody).ProtoLogs) + } + switch m.numCalls { case 1, 2: + // Assert that the request is of type *io.PipeReader, which implies that + // the request is not coming from the main thread. + assert.IsType(m.t, &io.PipeReader{}, request) + + // For the first two calls, we want to return an error to simulate a server + // timeout. return fmt.Errorf("server timeout") case 3: + // Assert that the request is of type *io.PipeReader, which implies that + // the request is not coming from the main thread. + assert.IsType(m.t, &io.PipeReader{}, request) + + // The call is successful but not all events are successfully logged. *(response.(*ResponseBody)) = ResponseBody{ NumProtoSuccess: 2, } case 4: + // Assert that the request is of type RequestBody, which implies that the + // request is coming from the main thread. + assertRequestPayload() return fmt.Errorf("some weird error") case 5: + // The call is successful but not all events are successfully logged. + assertRequestPayload() *(response.(*ResponseBody)) = ResponseBody{ NumProtoSuccess: 3, } case 6: + // The call is successful and all events are successfully logged. + assertRequestPayload() *(response.(*ResponseBody)) = ResponseBody{ NumProtoSuccess: 4, } @@ -42,11 +74,22 @@ func (m *mockDatabricksClient) Do(ctx context.Context, method, path string, head return nil } -// TODO: Run these tests multiple time to root out race conditions. -func TestTelemetryLoggerPersistentConnectionRetriesOnError(t *testing.T) { - mockClient := &mockDatabricksClient{} +// We run each of the unit tests multiple times to root out any race conditions +// that may exist. +func TestTelemetryLogger(t *testing.T) { + for i := 0; i < 5000; i++ { + t.Run("testPersistentConnectionRetriesOnError", testPersistentConnectionRetriesOnError) + t.Run("testFlush", testFlush) + t.Run("testFlushRespectsTimeout", testFlushRespectsTimeout) + } +} - ctx, _ := context.WithCancel(context.Background()) +func testPersistentConnectionRetriesOnError(t *testing.T) { + mockClient := &mockDatabricksClient{ + t: t, + } + + ctx := context.Background() l, err := NewLogger(ctx, mockClient) assert.NoError(t, err) @@ -62,36 +105,32 @@ func TestTelemetryLoggerPersistentConnectionRetriesOnError(t *testing.T) { assert.Equal(t, int64(2), resp.NumProtoSuccess) } -func TestTelemetryLogger(t *testing.T) { - mockClient := &mockDatabricksClient{} +func testFlush(t *testing.T) { + mockClient := &mockDatabricksClient{ + t: t, + } - ctx, _ := context.WithCancel(context.Background()) + ctx := context.Background() l, err := NewLogger(ctx, mockClient) assert.NoError(t, err) - // Add three events to be tracked and flushed. - l.TrackEvent(FrontendLogEntry{ - DatabricksCliLog: DatabricksCliLog{ - CliTestEvent: CliTestEvent{Name: DummyCliEnumValue1}, - }, - }) - l.TrackEvent(FrontendLogEntry{ - DatabricksCliLog: DatabricksCliLog{ - CliTestEvent: CliTestEvent{Name: DummyCliEnumValue2}, - }, - }) - l.TrackEvent(FrontendLogEntry{ - DatabricksCliLog: DatabricksCliLog{ - CliTestEvent: CliTestEvent{Name: DummyCliEnumValue2}, - }, - }) - l.TrackEvent(FrontendLogEntry{ - DatabricksCliLog: DatabricksCliLog{ - CliTestEvent: CliTestEvent{Name: DummyCliEnumValue3}, - }, + // Set the maximum additional wait time to 1 hour to ensure that the + // the Flush method does not timeout in the test run. + MaxAdditionalWaitTime = 1 * time.Hour + t.Cleanup(func() { + MaxAdditionalWaitTime = 1 * time.Second }) + // Add four events to be tracked and flushed. + for _, v := range []DummyCliEnum{DummyCliEnumValue1, DummyCliEnumValue2, DummyCliEnumValue2, DummyCliEnumValue3} { + l.TrackEvent(FrontendLogEntry{ + DatabricksCliLog: DatabricksCliLog{ + CliTestEvent: CliTestEvent{Name: v}, + }, + }) + } + // Flush the events. l.Flush() @@ -100,3 +139,37 @@ func TestTelemetryLogger(t *testing.T) { // the number of events we added. assert.Equal(t, 6, mockClient.numCalls) } + +func testFlushRespectsTimeout(t *testing.T) { + mockClient := &mockDatabricksClient{ + t: t, + } + + ctx := context.Background() + + l, err := NewLogger(ctx, mockClient) + assert.NoError(t, err) + + // Set the timer to 0 to ensure that the Flush method times out immediately. + MaxAdditionalWaitTime = 0 * time.Hour + t.Cleanup(func() { + MaxAdditionalWaitTime = 1 * time.Second + }) + + // Add four events to be tracked and flushed. + for _, v := range []DummyCliEnum{DummyCliEnumValue1, DummyCliEnumValue2, DummyCliEnumValue2, DummyCliEnumValue3} { + l.TrackEvent(FrontendLogEntry{ + DatabricksCliLog: DatabricksCliLog{ + CliTestEvent: CliTestEvent{Name: v}, + }, + }) + } + + // Flush the events. + l.Flush() + + // Assert that the .Do method was called less than or equal to 3 times. Since + // the timeout is set to 0, only the calls from the parallel go-routine should + // be made. The main thread should not make any calls. + assert.LessOrEqual(t, mockClient.numCalls, 3) +}