better unit tests

This commit is contained in:
Shreyas Goenka 2024-11-27 23:29:04 +01:00
parent f9ed0d5655
commit 46dd80d277
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
2 changed files with 105 additions and 32 deletions

View File

@ -125,7 +125,7 @@ func (l *logger) TrackEvent(event FrontendLogEntry) {
// Maximum additional time to wait for the telemetry event to flush. We expect the flush
// method to be called when the CLI command is about to exist, so this time would
// be purely additive to the end user's experience.
const MaxAdditionalWaitTime = 1 * time.Second
var MaxAdditionalWaitTime = 1 * time.Second
// TODO: Do not close the connection in-case of error. Allows for faster retry.
// TODO: Talk about why we make only one API call at the end. It's because the
@ -152,7 +152,7 @@ func (l *logger) Flush() {
}
// All logs were successfully sent. No need to retry.
if len(l.protoLogs) <= int(resp.NumProtoSuccess) && len(resp.Errors) > 0 {
if len(l.protoLogs) <= int(resp.NumProtoSuccess) && len(resp.Errors) == 0 {
return
}

View File

@ -3,35 +3,67 @@ package telemetry
import (
"context"
"fmt"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type mockDatabricksClient struct {
numCalls int
t *testing.T
}
// TODO: Assert on the request body provided to this method.
func (m *mockDatabricksClient) Do(ctx context.Context, method, path string, headers map[string]string, request, response any, visitors ...func(*http.Request) error) error {
// For the first two calls, we want to return an error to simulate a server
// timeout. For the third call, we want to return a successful response.
m.numCalls++
assertRequestPayload := func() {
expectedProtoLogs := []string{
"{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE1\"}}}",
"{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE2\"}}}",
"{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE2\"}}}",
"{\"databricks_cli_log\":{\"cli_test_event\":{\"name\":\"VALUE3\"}}}",
}
// Assert payload matches the expected payload.
assert.Equal(m.t, expectedProtoLogs, request.(RequestBody).ProtoLogs)
}
switch m.numCalls {
case 1, 2:
// Assert that the request is of type *io.PipeReader, which implies that
// the request is not coming from the main thread.
assert.IsType(m.t, &io.PipeReader{}, request)
// For the first two calls, we want to return an error to simulate a server
// timeout.
return fmt.Errorf("server timeout")
case 3:
// Assert that the request is of type *io.PipeReader, which implies that
// the request is not coming from the main thread.
assert.IsType(m.t, &io.PipeReader{}, request)
// The call is successful but not all events are successfully logged.
*(response.(*ResponseBody)) = ResponseBody{
NumProtoSuccess: 2,
}
case 4:
// Assert that the request is of type RequestBody, which implies that the
// request is coming from the main thread.
assertRequestPayload()
return fmt.Errorf("some weird error")
case 5:
// The call is successful but not all events are successfully logged.
assertRequestPayload()
*(response.(*ResponseBody)) = ResponseBody{
NumProtoSuccess: 3,
}
case 6:
// The call is successful and all events are successfully logged.
assertRequestPayload()
*(response.(*ResponseBody)) = ResponseBody{
NumProtoSuccess: 4,
}
@ -42,11 +74,22 @@ func (m *mockDatabricksClient) Do(ctx context.Context, method, path string, head
return nil
}
// TODO: Run these tests multiple time to root out race conditions.
func TestTelemetryLoggerPersistentConnectionRetriesOnError(t *testing.T) {
mockClient := &mockDatabricksClient{}
// We run each of the unit tests multiple times to root out any race conditions
// that may exist.
func TestTelemetryLogger(t *testing.T) {
for i := 0; i < 5000; i++ {
t.Run("testPersistentConnectionRetriesOnError", testPersistentConnectionRetriesOnError)
t.Run("testFlush", testFlush)
t.Run("testFlushRespectsTimeout", testFlushRespectsTimeout)
}
}
ctx, _ := context.WithCancel(context.Background())
func testPersistentConnectionRetriesOnError(t *testing.T) {
mockClient := &mockDatabricksClient{
t: t,
}
ctx := context.Background()
l, err := NewLogger(ctx, mockClient)
assert.NoError(t, err)
@ -62,36 +105,32 @@ func TestTelemetryLoggerPersistentConnectionRetriesOnError(t *testing.T) {
assert.Equal(t, int64(2), resp.NumProtoSuccess)
}
func TestTelemetryLogger(t *testing.T) {
mockClient := &mockDatabricksClient{}
func testFlush(t *testing.T) {
mockClient := &mockDatabricksClient{
t: t,
}
ctx, _ := context.WithCancel(context.Background())
ctx := context.Background()
l, err := NewLogger(ctx, mockClient)
assert.NoError(t, err)
// Add three events to be tracked and flushed.
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: DummyCliEnumValue1},
},
})
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: DummyCliEnumValue2},
},
})
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: DummyCliEnumValue2},
},
})
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: DummyCliEnumValue3},
},
// Set the maximum additional wait time to 1 hour to ensure that the
// the Flush method does not timeout in the test run.
MaxAdditionalWaitTime = 1 * time.Hour
t.Cleanup(func() {
MaxAdditionalWaitTime = 1 * time.Second
})
// Add four events to be tracked and flushed.
for _, v := range []DummyCliEnum{DummyCliEnumValue1, DummyCliEnumValue2, DummyCliEnumValue2, DummyCliEnumValue3} {
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: v},
},
})
}
// Flush the events.
l.Flush()
@ -100,3 +139,37 @@ func TestTelemetryLogger(t *testing.T) {
// the number of events we added.
assert.Equal(t, 6, mockClient.numCalls)
}
func testFlushRespectsTimeout(t *testing.T) {
mockClient := &mockDatabricksClient{
t: t,
}
ctx := context.Background()
l, err := NewLogger(ctx, mockClient)
assert.NoError(t, err)
// Set the timer to 0 to ensure that the Flush method times out immediately.
MaxAdditionalWaitTime = 0 * time.Hour
t.Cleanup(func() {
MaxAdditionalWaitTime = 1 * time.Second
})
// Add four events to be tracked and flushed.
for _, v := range []DummyCliEnum{DummyCliEnumValue1, DummyCliEnumValue2, DummyCliEnumValue2, DummyCliEnumValue3} {
l.TrackEvent(FrontendLogEntry{
DatabricksCliLog: DatabricksCliLog{
CliTestEvent: CliTestEvent{Name: v},
},
})
}
// Flush the events.
l.Flush()
// Assert that the .Do method was called less than or equal to 3 times. Since
// the timeout is set to 0, only the calls from the parallel go-routine should
// be made. The main thread should not make any calls.
assert.LessOrEqual(t, mockClient.numCalls, 3)
}