diff --git a/bundle/config/environment.go b/bundle/config/environment.go index 3e66977e..9d8c0282 100644 --- a/bundle/config/environment.go +++ b/bundle/config/environment.go @@ -29,6 +29,8 @@ type Environment struct { // Does not permit defining new variables or redefining existing ones // in the scope of an environment Variables map[string]string `json:"variables,omitempty"` + + Git Git `json:"git,omitempty"` } const ( diff --git a/bundle/config/mutator/process_environment_mode.go b/bundle/config/mutator/process_environment_mode.go index 1186b934..33f4eec8 100644 --- a/bundle/config/mutator/process_environment_mode.go +++ b/bundle/config/mutator/process_environment_mode.go @@ -14,10 +14,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/ml" ) -type processEnvironmentMode struct { - // getPrincipalGetByIdImpl overrides the GetPrincipalGetById implementation for testing purposes. - getPrincipalGetByIdImpl func(ctx context.Context, id string) (*iam.ServicePrincipal, error) -} +type processEnvironmentMode struct {} const developmentConcurrentRuns = 4 @@ -97,7 +94,7 @@ func isUserSpecificDeployment(b *bundle.Bundle) bool { !strings.Contains(b.Config.Workspace.FilesPath, username) } -func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b *bundle.Bundle) error { +func validateProductionMode(ctx context.Context, b *bundle.Bundle, isPrincipalUsed bool) error { if b.Config.Bundle.Git.Inferred { TODO: show a nice human error here? :( return fmt.Errorf("environment with 'mode: production' must specify an explicit 'git' configuration") @@ -110,12 +107,7 @@ func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b * } } - isPrincipal, err := m.isServicePrincipalUsed(ctx, b) - if err != nil { - return err - } - - if !isPrincipal { + if !isPrincipalUsed { if isUserSpecificDeployment(b) { return fmt.Errorf("environment with 'mode: development' must deploy to a location specific to the user, and should e.g. set 'root_path: ~/.bundle/${bundle.name}/${bundle.environment}'") } @@ -128,15 +120,10 @@ func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b * } // Determines whether a service principal identity is used to run the CLI. -func (m *processEnvironmentMode) isServicePrincipalUsed(ctx context.Context, b *bundle.Bundle) (bool, error) { +func isServicePrincipalUsed(ctx context.Context, b *bundle.Bundle) (bool, error) { ws := b.WorkspaceClient() - getPrincipalById := m.getPrincipalGetByIdImpl - if getPrincipalById == nil { - getPrincipalById = ws.ServicePrincipals.GetById - } - - _, err := getPrincipalById(ctx, b.Config.Workspace.CurrentUser.Id) + _, err := ws.ServicePrincipals.GetById(ctx, b.Config.Workspace.CurrentUser.Id) if err != nil { apiError, ok := err.(*apierr.APIError) if ok && apiError.StatusCode == 404 { @@ -179,7 +166,11 @@ func (m *processEnvironmentMode) Apply(ctx context.Context, b *bundle.Bundle) er } return transformDevelopmentMode(b) case config.Production: - return m.validateProductionMode(ctx, b) + isPrincipal, err := m.isServicePrincipalUsed(ctx, b) + if err != nil { + return err + } + return validateProductionMode(ctx, b, isPrincipal) case "": // No action default: diff --git a/bundle/config/mutator/process_environment_mode_test.go b/bundle/config/mutator/process_environment_mode_test.go index 6f5590d6..f48b3d17 100644 --- a/bundle/config/mutator/process_environment_mode_test.go +++ b/bundle/config/mutator/process_environment_mode_test.go @@ -9,7 +9,6 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" - "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" @@ -87,12 +86,8 @@ func TestProcessEnvironmentModeProduction(t *testing.T) { bundle.Config.Workspace.ArtifactsPath = "/Shared/.bundle/x/y/artifacts" bundle.Config.Workspace.FilesPath = "/Shared/.bundle/x/y/files" - m := ProcessEnvironmentMode() - m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) { - return nil, &apierr.APIError{StatusCode: 404} - } + err := validateProductionMode(context.Background(), bundle, false) - err := m.Apply(context.Background(), bundle) require.NoError(t, err) assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) @@ -102,24 +97,16 @@ func TestProcessEnvironmentModeProduction(t *testing.T) { func TestProcessEnvironmentModeProductionFails(t *testing.T) { bundle := mockBundle(config.Production) - m := ProcessEnvironmentMode() - m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) { - return nil, &apierr.APIError{StatusCode: 404} - } + err := validateProductionMode(context.Background(), bundle, false) - err := m.Apply(context.Background(), bundle) require.Error(t, err) } func TestProcessEnvironmentModeProductionOkForPrincipal(t *testing.T) { bundle := mockBundle(config.Production) - m := ProcessEnvironmentMode() - m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) { - return nil, nil - } + err := validateProductionMode(context.Background(), bundle, false) - err := m.Apply(context.Background(), bundle) require.NoError(t, err) } diff --git a/bundle/config/root.go b/bundle/config/root.go index f1ae9c0a..13f5e486 100644 --- a/bundle/config/root.go +++ b/bundle/config/root.go @@ -198,5 +198,16 @@ func (r *Root) MergeEnvironment(env *Environment) error { r.Bundle.ComputeID = env.ComputeID } + if env.Git.Branch != "" { + r.Bundle.Git.Branch = env.Git.Branch + r.Bundle.Git.Inferred = false + } + if env.Git.Commit != "" { + r.Bundle.Git.Commit = env.Git.Commit + } + if env.Git.OriginURL != "" { + r.Bundle.Git.OriginURL = env.Git.OriginURL + } + return nil }