This commit is contained in:
Lennart Kats 2023-07-10 09:21:14 +02:00
parent 368796ba12
commit f59e73a026
4 changed files with 26 additions and 35 deletions

View File

@ -29,6 +29,8 @@ type Environment struct {
// Does not permit defining new variables or redefining existing ones // Does not permit defining new variables or redefining existing ones
// in the scope of an environment // in the scope of an environment
Variables map[string]string `json:"variables,omitempty"` Variables map[string]string `json:"variables,omitempty"`
Git Git `json:"git,omitempty"`
} }
const ( const (

View File

@ -14,10 +14,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/ml"
) )
type processEnvironmentMode struct { type processEnvironmentMode struct {}
// getPrincipalGetByIdImpl overrides the GetPrincipalGetById implementation for testing purposes.
getPrincipalGetByIdImpl func(ctx context.Context, id string) (*iam.ServicePrincipal, error)
}
const developmentConcurrentRuns = 4 const developmentConcurrentRuns = 4
@ -97,7 +94,7 @@ func isUserSpecificDeployment(b *bundle.Bundle) bool {
!strings.Contains(b.Config.Workspace.FilesPath, username) !strings.Contains(b.Config.Workspace.FilesPath, username)
} }
func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b *bundle.Bundle) error { func validateProductionMode(ctx context.Context, b *bundle.Bundle, isPrincipalUsed bool) error {
if b.Config.Bundle.Git.Inferred { if b.Config.Bundle.Git.Inferred {
TODO: show a nice human error here? :( TODO: show a nice human error here? :(
return fmt.Errorf("environment with 'mode: production' must specify an explicit 'git' configuration") return fmt.Errorf("environment with 'mode: production' must specify an explicit 'git' configuration")
@ -110,12 +107,7 @@ func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b *
} }
} }
isPrincipal, err := m.isServicePrincipalUsed(ctx, b) if !isPrincipalUsed {
if err != nil {
return err
}
if !isPrincipal {
if isUserSpecificDeployment(b) { if isUserSpecificDeployment(b) {
return fmt.Errorf("environment with 'mode: development' must deploy to a location specific to the user, and should e.g. set 'root_path: ~/.bundle/${bundle.name}/${bundle.environment}'") return fmt.Errorf("environment with 'mode: development' must deploy to a location specific to the user, and should e.g. set 'root_path: ~/.bundle/${bundle.name}/${bundle.environment}'")
} }
@ -128,15 +120,10 @@ func (m *processEnvironmentMode) validateProductionMode(ctx context.Context, b *
} }
// Determines whether a service principal identity is used to run the CLI. // Determines whether a service principal identity is used to run the CLI.
func (m *processEnvironmentMode) isServicePrincipalUsed(ctx context.Context, b *bundle.Bundle) (bool, error) { func isServicePrincipalUsed(ctx context.Context, b *bundle.Bundle) (bool, error) {
ws := b.WorkspaceClient() ws := b.WorkspaceClient()
getPrincipalById := m.getPrincipalGetByIdImpl _, err := ws.ServicePrincipals.GetById(ctx, b.Config.Workspace.CurrentUser.Id)
if getPrincipalById == nil {
getPrincipalById = ws.ServicePrincipals.GetById
}
_, err := getPrincipalById(ctx, b.Config.Workspace.CurrentUser.Id)
if err != nil { if err != nil {
apiError, ok := err.(*apierr.APIError) apiError, ok := err.(*apierr.APIError)
if ok && apiError.StatusCode == 404 { if ok && apiError.StatusCode == 404 {
@ -179,7 +166,11 @@ func (m *processEnvironmentMode) Apply(ctx context.Context, b *bundle.Bundle) er
} }
return transformDevelopmentMode(b) return transformDevelopmentMode(b)
case config.Production: case config.Production:
return m.validateProductionMode(ctx, b) isPrincipal, err := m.isServicePrincipalUsed(ctx, b)
if err != nil {
return err
}
return validateProductionMode(ctx, b, isPrincipal)
case "": case "":
// No action // No action
default: default:

View File

@ -9,7 +9,6 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/ml"
@ -87,12 +86,8 @@ func TestProcessEnvironmentModeProduction(t *testing.T) {
bundle.Config.Workspace.ArtifactsPath = "/Shared/.bundle/x/y/artifacts" bundle.Config.Workspace.ArtifactsPath = "/Shared/.bundle/x/y/artifacts"
bundle.Config.Workspace.FilesPath = "/Shared/.bundle/x/y/files" bundle.Config.Workspace.FilesPath = "/Shared/.bundle/x/y/files"
m := ProcessEnvironmentMode() err := validateProductionMode(context.Background(), bundle, false)
m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) {
return nil, &apierr.APIError{StatusCode: 404}
}
err := m.Apply(context.Background(), bundle)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
@ -102,24 +97,16 @@ func TestProcessEnvironmentModeProduction(t *testing.T) {
func TestProcessEnvironmentModeProductionFails(t *testing.T) { func TestProcessEnvironmentModeProductionFails(t *testing.T) {
bundle := mockBundle(config.Production) bundle := mockBundle(config.Production)
m := ProcessEnvironmentMode() err := validateProductionMode(context.Background(), bundle, false)
m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) {
return nil, &apierr.APIError{StatusCode: 404}
}
err := m.Apply(context.Background(), bundle)
require.Error(t, err) require.Error(t, err)
} }
func TestProcessEnvironmentModeProductionOkForPrincipal(t *testing.T) { func TestProcessEnvironmentModeProductionOkForPrincipal(t *testing.T) {
bundle := mockBundle(config.Production) bundle := mockBundle(config.Production)
m := ProcessEnvironmentMode() err := validateProductionMode(context.Background(), bundle, false)
m.getPrincipalGetByIdImpl = func(ctx context.Context, id string) (*iam.ServicePrincipal, error) {
return nil, nil
}
err := m.Apply(context.Background(), bundle)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -198,5 +198,16 @@ func (r *Root) MergeEnvironment(env *Environment) error {
r.Bundle.ComputeID = env.ComputeID r.Bundle.ComputeID = env.ComputeID
} }
if env.Git.Branch != "" {
r.Bundle.Git.Branch = env.Git.Branch
r.Bundle.Git.Inferred = false
}
if env.Git.Commit != "" {
r.Bundle.Git.Commit = env.Git.Commit
}
if env.Git.OriginURL != "" {
r.Bundle.Git.OriginURL = env.Git.OriginURL
}
return nil return nil
} }