From 549e49e2ae1bec5db5d093a95ff236d270e32444 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 24 Feb 2025 18:05:19 +0100 Subject: [PATCH] cancel test --- libs/telemetry/upload.go | 4 +-- libs/telemetry/upload_test.go | 68 ++++++++++++++++++++++------------- main.go | 2 +- 3 files changed, 47 insertions(+), 27 deletions(-) diff --git a/libs/telemetry/upload.go b/libs/telemetry/upload.go index 7ecda241b..a6e8e44c6 100644 --- a/libs/telemetry/upload.go +++ b/libs/telemetry/upload.go @@ -37,7 +37,7 @@ type UploadConfig struct { // Upload reads telemetry logs from stdin and uploads them to the telemetry endpoint. // This function is always expected to be called in a separate child process from // the main CLI process. -func Upload() (*ResponseBody, error) { +func Upload(ctx context.Context) (*ResponseBody, error) { var err error b, err := io.ReadAll(os.Stdin) @@ -80,7 +80,7 @@ func Upload() (*ResponseBody, error) { } // Set a maximum total time to try telemetry uploads. - ctx, cancel := context.WithTimeout(context.Background(), maxUploadTime) + ctx, cancel := context.WithTimeout(ctx, maxUploadTime) defer cancel() resp := &ResponseBody{} diff --git a/libs/telemetry/upload_test.go b/libs/telemetry/upload_test.go index cd8809519..ec6e8fe37 100644 --- a/libs/telemetry/upload_test.go +++ b/libs/telemetry/upload_test.go @@ -1,6 +1,7 @@ package telemetry import ( + "context" "encoding/json" "os" "path/filepath" @@ -14,29 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestTelemetryUploadRetries(t *testing.T) { - server := testserver.New(t) - t.Cleanup(server.Close) - - count := 0 - server.Handle("POST", "/telemetry-ext", func(req testserver.Request) any { - count++ - if count == 1 { - return ResponseBody{ - NumProtoSuccess: 1, - } - } - if count == 2 { - return ResponseBody{ - NumProtoSuccess: 2, - } - } - return nil - }) - - t.Setenv("DATABRICKS_HOST", server.URL) - t.Setenv("DATABRICKS_TOKEN", "token") - +func configureStdin(t *testing.T) { logs := []protos.FrontendLog{ { FrontendLogEventID: uuid.New().String(), @@ -76,9 +55,50 @@ func TestTelemetryUploadRetries(t *testing.T) { f.Close() os.Stdin = old }) +} - resp, err := Upload() +func TestTelemetryUploadRetries(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + + count := 0 + server.Handle("POST", "/telemetry-ext", func(req testserver.Request) any { + count++ + if count == 1 { + return ResponseBody{ + NumProtoSuccess: 1, + } + } + if count == 2 { + return ResponseBody{ + NumProtoSuccess: 2, + } + } + return nil + }) + + t.Setenv("DATABRICKS_HOST", server.URL) + t.Setenv("DATABRICKS_TOKEN", "token") + + configureStdin(t) + + resp, err := Upload(context.Background()) require.NoError(t, err) assert.Equal(t, int64(2), resp.NumProtoSuccess) assert.Equal(t, 2, count) } + +func TestTelemetryUploadCanceled(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + configureStdin(t) + _, err := Upload(ctx) + + // Since the context is already cancelled, upload should fail immediately + // with a timeout error. + assert.ErrorContains(t, err, "Failed to flush telemetry log due to timeout") +} diff --git a/main.go b/main.go index 0a5a85eba..70ba7d4d1 100644 --- a/main.go +++ b/main.go @@ -41,7 +41,7 @@ func main() { Level: logger.LevelError, } - resp, err := telemetry.Upload() + resp, err := telemetry.Upload(ctx) if err != nil { fmt.Fprintf(errW, "error: %s\n", err) os.Exit(1)