diff --git a/cmd/root/root.go b/cmd/root/root.go index 6bea7806..ac9116bc 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -23,8 +23,11 @@ var RootCmd = &cobra.Command{ SilenceUsage: true, PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + // Configure our user agent with the command that's about to be executed. - ctx := withCommandInUserAgent(cmd.Context(), cmd) + ctx = withCommandInUserAgent(ctx, cmd) + ctx = withUpstreamInUserAgent(ctx) cmd.SetContext(ctx) if Verbose { diff --git a/cmd/root/user_agent_upstream.go b/cmd/root/user_agent_upstream.go new file mode 100644 index 00000000..7d75650c --- /dev/null +++ b/cmd/root/user_agent_upstream.go @@ -0,0 +1,33 @@ +package root + +import ( + "context" + "os" + + "github.com/databricks/databricks-sdk-go/useragent" +) + +// Environment variables that caller can set to convey what is upstream to bricks. +const upstreamEnvVar = "BRICKS_UPSTREAM" +const upstreamVersionEnvVar = "BRICKS_UPSTREAM_VERSION" + +// Keys in the user agent. +const upstreamKey = "upstream" +const upstreamVersionKey = "upstream-version" + +func withUpstreamInUserAgent(ctx context.Context) context.Context { + value := os.Getenv(upstreamEnvVar) + if value == "" { + return ctx + } + + ctx = useragent.InContext(ctx, upstreamKey, value) + + // Include upstream version as well, if set. + value = os.Getenv(upstreamVersionEnvVar) + if value == "" { + return ctx + } + + return useragent.InContext(ctx, upstreamVersionKey, value) +} diff --git a/cmd/root/user_agent_upstream_test.go b/cmd/root/user_agent_upstream_test.go new file mode 100644 index 00000000..fc6ea0c7 --- /dev/null +++ b/cmd/root/user_agent_upstream_test.go @@ -0,0 +1,45 @@ +package root + +import ( + "context" + "testing" + + "github.com/databricks/databricks-sdk-go/useragent" + "github.com/stretchr/testify/assert" +) + +func TestUpstreamSet(t *testing.T) { + t.Setenv(upstreamEnvVar, "foobar") + ctx := withUpstreamInUserAgent(context.Background()) + assert.Contains(t, useragent.FromContext(ctx), "upstream/foobar") +} + +func TestUpstreamSetEmpty(t *testing.T) { + t.Setenv(upstreamEnvVar, "") + ctx := withUpstreamInUserAgent(context.Background()) + assert.NotContains(t, useragent.FromContext(ctx), "upstream/") +} + +func TestUpstreamVersionSet(t *testing.T) { + t.Setenv(upstreamEnvVar, "foobar") + t.Setenv(upstreamVersionEnvVar, "0.0.1") + ctx := withUpstreamInUserAgent(context.Background()) + assert.Contains(t, useragent.FromContext(ctx), "upstream/foobar") + assert.Contains(t, useragent.FromContext(ctx), "upstream-version/0.0.1") +} + +func TestUpstreamVersionSetEmpty(t *testing.T) { + t.Setenv(upstreamEnvVar, "foobar") + t.Setenv(upstreamVersionEnvVar, "") + ctx := withUpstreamInUserAgent(context.Background()) + assert.Contains(t, useragent.FromContext(ctx), "upstream/foobar") + assert.NotContains(t, useragent.FromContext(ctx), "upstream-version/") +} + +func TestUpstreamVersionSetUpstreamNotSet(t *testing.T) { + t.Setenv(upstreamEnvVar, "") + t.Setenv(upstreamVersionEnvVar, "0.0.1") + ctx := withUpstreamInUserAgent(context.Background()) + assert.NotContains(t, useragent.FromContext(ctx), "upstream/") + assert.NotContains(t, useragent.FromContext(ctx), "upstream-version/") +}