diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index df98cc151..f47b41990 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/databricks/cli/cmd" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/auth" "github.com/databricks/cli/libs/auth/cache" "github.com/databricks/cli/libs/databrickscfg/profile" @@ -106,7 +107,7 @@ func getCobraCmdForTest(f fixtures.HTTPFixture) (*cobra.Command, *bytes.Buffer) func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) cmd.SetArgs([]string{"auth", "token", "--profile", "expired"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() assert.Empty(t, out) @@ -117,7 +118,7 @@ func TestTokenCmdWithProfilePrintsHelpfulLoginMessageOnRefreshFailure(t *testing func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) { cmd, output := getCobraCmdForTest(refreshFailureTokenResponse) cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() assert.Empty(t, out) @@ -128,7 +129,7 @@ func TestTokenCmdWithHostPrintsHelpfulLoginMessageOnRefreshFailure(t *testing.T) func TestTokenCmdInvalidResponse(t *testing.T) { cmd, output := getCobraCmdForTest(refreshFailureInvalidResponse) cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() assert.Empty(t, out) @@ -139,7 +140,7 @@ func TestTokenCmdInvalidResponse(t *testing.T) { func TestTokenCmdOtherErrorResponse(t *testing.T) { cmd, output := getCobraCmdForTest(refreshFailureOtherError) cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() assert.Empty(t, out) @@ -150,7 +151,7 @@ func TestTokenCmdOtherErrorResponse(t *testing.T) { func TestTokenCmdWithProfileSuccess(t *testing.T) { cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) cmd.SetArgs([]string{"auth", "token", "--profile", "active"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() validateToken(t, out) @@ -160,7 +161,7 @@ func TestTokenCmdWithProfileSuccess(t *testing.T) { func TestTokenCmdWithHostSuccess(t *testing.T) { cmd, output := getCobraCmdForTest(refreshSuccessTokenResponse) cmd.SetArgs([]string{"auth", "token", "--host", "https://accounts.cloud.databricks.com", "--account-id", "expired"}) - err := cmd.Execute() + err := root.Execute(cmd.Context(), cmd) out := output.String() validateToken(t, out) diff --git a/cmd/configure/configure_test.go b/cmd/configure/configure_test.go index 14eb0674a..309c65363 100644 --- a/cmd/configure/configure_test.go +++ b/cmd/configure/configure_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/databricks/cli/cmd" + "github.com/databricks/cli/cmd/root" "github.com/stretchr/testify/assert" "gopkg.in/ini.v1" ) @@ -57,7 +58,7 @@ func TestDefaultConfigureNoInteractive(t *testing.T) { cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) - err := cmd.ExecuteContext(ctx) + err := root.Execute(ctx, cmd) assert.NoError(t, err) cfgPath := filepath.Join(tempHomeDir, ".databrickscfg") @@ -91,7 +92,7 @@ func TestConfigFileFromEnvNoInteractive(t *testing.T) { cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host"}) - err := cmd.ExecuteContext(ctx) + err := root.Execute(ctx, cmd) assert.NoError(t, err) _, err = os.Stat(cfgPath) @@ -131,7 +132,7 @@ func TestEnvVarsConfigureNoInteractive(t *testing.T) { cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token"}) - err := cmd.ExecuteContext(ctx) + err := root.Execute(ctx, cmd) assert.NoError(t, err) _, err = os.Stat(cfgPath) @@ -164,7 +165,7 @@ func TestEnvVarsConfigureNoArgsNoInteractive(t *testing.T) { cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure"}) - err := cmd.ExecuteContext(ctx) + err := root.Execute(ctx, cmd) assert.NoError(t, err) _, err = os.Stat(cfgPath) @@ -193,7 +194,7 @@ func TestCustomProfileConfigureNoInteractive(t *testing.T) { cmd := cmd.New(ctx) cmd.SetArgs([]string{"configure", "--token", "--host", "https://host", "--profile", "CUSTOM"}) - err := cmd.ExecuteContext(ctx) + err := root.Execute(ctx, cmd) assert.NoError(t, err) _, err = os.Stat(cfgPath) diff --git a/cmd/root/bundle_test.go b/cmd/root/bundle_test.go index 5871b0ae9..a3f3395d8 100644 --- a/cmd/root/bundle_test.go +++ b/cmd/root/bundle_test.go @@ -213,7 +213,7 @@ func TestTargetFlagFull(t *testing.T) { cmd.SetArgs([]string{"version", "--target", "development"}) ctx := context.Background() - err := cmd.ExecuteContext(ctx) + err := Execute(ctx, cmd) assert.NoError(t, err) assert.Equal(t, "development", getTarget(cmd)) @@ -225,7 +225,7 @@ func TestTargetFlagShort(t *testing.T) { cmd.SetArgs([]string{"version", "-t", "production"}) ctx := context.Background() - err := cmd.ExecuteContext(ctx) + err := Execute(ctx, cmd) assert.NoError(t, err) assert.Equal(t, "production", getTarget(cmd)) @@ -239,7 +239,7 @@ func TestTargetEnvironmentFlag(t *testing.T) { cmd.SetArgs([]string{"version", "--environment", "development"}) ctx := context.Background() - err := cmd.ExecuteContext(ctx) + err := Execute(ctx, cmd) assert.NoError(t, err) assert.Equal(t, "development", getTarget(cmd)) diff --git a/cmd/root/root.go b/cmd/root/root.go index 04815f48b..9e3fa4d2e 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/command" "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/log" "github.com/spf13/cobra" @@ -124,6 +125,9 @@ Stack Trace: %s`, version, r, string(trace)) }() + // Set a command execution ID value in the context + ctx = command.GenerateExecId(ctx) + // Run the command cmd, err = cmd.ExecuteContextC(ctx) if err != nil && !errors.Is(err, ErrAlreadyPrinted) { diff --git a/cmd/root/user_agent_command_exec_id.go b/cmd/root/user_agent_command_exec_id.go index 3bf32b703..dd165380a 100644 --- a/cmd/root/user_agent_command_exec_id.go +++ b/cmd/root/user_agent_command_exec_id.go @@ -3,12 +3,12 @@ package root import ( "context" + "github.com/databricks/cli/libs/command" "github.com/databricks/databricks-sdk-go/useragent" - "github.com/google/uuid" ) func withCommandExecIdInUserAgent(ctx context.Context) context.Context { // A UUID that will allow us to correlate multiple API requests made by // the same CLI invocation. - return useragent.InContext(ctx, "cmd-exec-id", uuid.New().String()) + return useragent.InContext(ctx, "cmd-exec-id", command.ExecId(ctx)) } diff --git a/cmd/root/user_agent_command_exec_id_test.go b/cmd/root/user_agent_command_exec_id_test.go index 5c4365107..c3d95b4a3 100644 --- a/cmd/root/user_agent_command_exec_id_test.go +++ b/cmd/root/user_agent_command_exec_id_test.go @@ -2,25 +2,18 @@ package root import ( "context" - "regexp" "testing" + "github.com/databricks/cli/libs/command" "github.com/databricks/databricks-sdk-go/useragent" - "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestWithCommandExecIdInUserAgent(t *testing.T) { - ctx := withCommandExecIdInUserAgent(context.Background()) + ctx := command.GenerateExecId(context.Background()) + ctx = withCommandExecIdInUserAgent(ctx) - // Check that the command exec ID is in the user agent string. + // user agent should contain cmd-exec-id/ ua := useragent.FromContext(ctx) - re := regexp.MustCompile(`cmd-exec-id/([a-f0-9-]+)`) - matches := re.FindAllStringSubmatch(ua, -1) - - // Assert that we have exactly one match and that it's a valid UUID. - require.Len(t, matches, 1) - _, err := uuid.Parse(matches[0][1]) - assert.NoError(t, err) + assert.Regexp(t, `cmd-exec-id/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}`, ua) } diff --git a/libs/command/context.go b/libs/command/context.go index e87a6f353..6249f997b 100644 --- a/libs/command/context.go +++ b/libs/command/context.go @@ -1,5 +1,11 @@ package command +import ( + "context" + + "github.com/google/uuid" +) + // key is a package-local type to use for context keys. // // Using an unexported type for context keys prevents key collisions across @@ -7,6 +13,11 @@ package command type key int const ( + // execIdKey is the context key for the execution ID. + // The value of 1 is arbitrary and can be any number. + // Other keys in the same package must have different values. + execIdKey = key(1) + // configUsedKey is the context key for the auth configuration used to run the // command. configUsedKey = key(2) @@ -15,3 +26,21 @@ const ( // client that can be used to make authenticated requests. workspaceClientKey = key(3) ) + +func GenerateExecId(ctx context.Context) context.Context { + if v := ctx.Value(execIdKey); v != nil { + panic("command.SetExecId called twice on the same context") + } + return context.WithValue(ctx, execIdKey, uuid.New().String()) +} + +// ExecId returns a UUID value that is guaranteed to be the same throughout +// the lifetime of the command execution, and unique for each invocation of the +// CLI. +func ExecId(ctx context.Context) string { + v := ctx.Value(execIdKey) + if v == nil { + panic("command.ExecId called without calling command.SetExecId first") + } + return v.(string) +} diff --git a/libs/command/context_test.go b/libs/command/context_test.go new file mode 100644 index 000000000..252199d46 --- /dev/null +++ b/libs/command/context_test.go @@ -0,0 +1,50 @@ +package command + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestCommandGenerateExecIdPanics(t *testing.T) { + ctx := context.Background() + + // Set the execution ID. + ctx = GenerateExecId(ctx) + + // Expect a panic if the execution ID is set twice. + assert.Panics(t, func() { + ctx = GenerateExecId(ctx) + }) +} + +func TestCommandExecIdPanics(t *testing.T) { + ctx := context.Background() + + // Expect a panic if the execution ID is not set. + assert.Panics(t, func() { + ExecId(ctx) + }) +} + +func TestCommandGenerateExecId(t *testing.T) { + ctx := context.Background() + + // Set the execution ID. + ctx = GenerateExecId(ctx) + + // Expect no panic because the execution ID is set. + assert.NotPanics(t, func() { + ExecId(ctx) + }) + + v := ExecId(ctx) + + // Subsequent calls should return the same value. + assert.Equal(t, v, ExecId(ctx)) + + // The value should be a valid UUID. + assert.NoError(t, uuid.Validate(v)) +}