diff --git a/Makefile b/Makefile index 6067d45b..3c55b8cf 100644 --- a/Makefile +++ b/Makefile @@ -30,4 +30,5 @@ vendor: @echo "✓ Filling vendor folder with library code ..." @go mod vendor -.PHONY: build vendor coverage test lint fmt \ No newline at end of file +.PHONY: build vendor coverage test lint fmt + diff --git a/bundle/config/mutator/process_target_mode.go b/bundle/config/mutator/process_target_mode.go index 2f80fe3b..c11bd1c5 100644 --- a/bundle/config/mutator/process_target_mode.go +++ b/bundle/config/mutator/process_target_mode.go @@ -87,6 +87,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error { // (model serving doesn't yet support tags) } + for i := range r.RegisteredModels { + prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_" + r.RegisteredModels[i].Name = prefix + r.RegisteredModels[i].Name + // (registered models in Unity Catalog don't yet support tags) + } + return nil } diff --git a/bundle/config/mutator/process_target_mode_test.go b/bundle/config/mutator/process_target_mode_test.go index a0b2bac8..a9da0b0f 100644 --- a/bundle/config/mutator/process_target_mode_test.go +++ b/bundle/config/mutator/process_target_mode_test.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/libs/tags" sdkconfig "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" @@ -59,6 +60,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle { ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ "servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}}, }, + RegisteredModels: map[string]*resources.RegisteredModel{ + "registeredmodel1": {CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{Name: "registeredmodel1"}}, + }, }, }, // Use AWS implementation for testing. @@ -86,6 +90,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) { // Experiment 1 assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name) assert.Contains(t, bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags, ml.ExperimentTag{Key: "dev", Value: "lennart"}) + assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key) // Experiment 2 assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name) @@ -96,7 +101,9 @@ func TestProcessTargetModeDevelopment(t *testing.T) { // Model serving endpoint 1 assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) - assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key) + + // Registered model 1 + assert.Equal(t, "dev_lennart_registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name) } func TestProcessTargetModeDevelopmentTagNormalizationForAws(t *testing.T) { @@ -151,6 +158,7 @@ func TestProcessTargetModeDefault(t *testing.T) { assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) + assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name) } func TestProcessTargetModeProduction(t *testing.T) { @@ -187,6 +195,7 @@ func TestProcessTargetModeProduction(t *testing.T) { assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) + assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name) } func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) { diff --git a/bundle/config/resources.go b/bundle/config/resources.go index ad1d6e9a..2b453c66 100644 --- a/bundle/config/resources.go +++ b/bundle/config/resources.go @@ -14,6 +14,7 @@ type Resources struct { Models map[string]*resources.MlflowModel `json:"models,omitempty"` Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"` + RegisteredModels map[string]*resources.RegisteredModel `json:"registered_models,omitempty"` } type UniqueResourceIdTracker struct { @@ -107,6 +108,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker, tracker.Type[k] = "model_serving_endpoint" tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath } + for k := range r.RegisteredModels { + if _, ok := tracker.Type[k]; ok { + return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)", + k, + tracker.Type[k], + tracker.ConfigPath[k], + "registered_model", + r.RegisteredModels[k].ConfigFilePath, + ) + } + tracker.Type[k] = "registered_model" + tracker.ConfigPath[k] = r.RegisteredModels[k].ConfigFilePath + } return tracker, nil } @@ -129,6 +143,9 @@ func (r *Resources) SetConfigFilePath(path string) { for _, e := range r.ModelServingEndpoints { e.ConfigFilePath = path } + for _, e := range r.RegisteredModels { + e.ConfigFilePath = path + } } // Merge iterates over all resources and merges chunks of the diff --git a/bundle/config/resources/grant.go b/bundle/config/resources/grant.go new file mode 100644 index 00000000..f0ecd876 --- /dev/null +++ b/bundle/config/resources/grant.go @@ -0,0 +1,9 @@ +package resources + +// Grant holds the grant level settings for a single principal in Unity Catalog. +// Multiple of these can be defined on any Unity Catalog resource. +type Grant struct { + Privileges []string `json:"privileges"` + + Principal string `json:"principal"` +} diff --git a/bundle/config/resources/model_serving_endpoint.go b/bundle/config/resources/model_serving_endpoint.go index 3847e6a6..88a55ac8 100644 --- a/bundle/config/resources/model_serving_endpoint.go +++ b/bundle/config/resources/model_serving_endpoint.go @@ -15,8 +15,8 @@ type ModelServingEndpoint struct { // as a reference in other resources. This value is returned by terraform. ID string - // Local path where the bundle is defined. All bundle resources include - // this for interpolation purposes. + // Path to config file where the resource is defined. All bundle resources + // include this for interpolation purposes. paths.Paths // This is a resource agnostic implementation of permissions for ACLs. diff --git a/bundle/config/resources/registered_model.go b/bundle/config/resources/registered_model.go new file mode 100644 index 00000000..32a451a2 --- /dev/null +++ b/bundle/config/resources/registered_model.go @@ -0,0 +1,34 @@ +package resources + +import ( + "github.com/databricks/cli/bundle/config/paths" + "github.com/databricks/databricks-sdk-go/marshal" + "github.com/databricks/databricks-sdk-go/service/catalog" +) + +type RegisteredModel struct { + // This is a resource agnostic implementation of grants. + // Implementation could be different based on the resource type. + Grants []Grant `json:"grants,omitempty"` + + // This represents the id which is the full name of the model + // (catalog_name.schema_name.model_name) that can be used + // as a reference in other resources. This value is returned by terraform. + ID string + + // Path to config file where the resource is defined. All bundle resources + // include this for interpolation purposes. + paths.Paths + + // This represents the input args for terraform, and will get converted + // to a HCL representation for CRUD + *catalog.CreateRegisteredModelRequest +} + +func (s *RegisteredModel) UnmarshalJSON(b []byte) error { + return marshal.Unmarshal(b, s) +} + +func (s RegisteredModel) MarshalJSON() ([]byte, error) { + return marshal.Marshal(s) +} diff --git a/bundle/config/resources_test.go b/bundle/config/resources_test.go index 82cb9f45..9c4104e4 100644 --- a/bundle/config/resources_test.go +++ b/bundle/config/resources_test.go @@ -95,3 +95,33 @@ func TestVerifySafeMergeForSameResourceType(t *testing.T) { err := r.VerifySafeMerge(&other) assert.ErrorContains(t, err, "multiple resources named foo (job at foo.yml, job at foo2.yml)") } + +func TestVerifySafeMergeForRegisteredModels(t *testing.T) { + r := Resources{ + Jobs: map[string]*resources.Job{ + "foo": { + Paths: paths.Paths{ + ConfigFilePath: "foo.yml", + }, + }, + }, + RegisteredModels: map[string]*resources.RegisteredModel{ + "bar": { + Paths: paths.Paths{ + ConfigFilePath: "bar.yml", + }, + }, + }, + } + other := Resources{ + RegisteredModels: map[string]*resources.RegisteredModel{ + "bar": { + Paths: paths.Paths{ + ConfigFilePath: "bar2.yml", + }, + }, + }, + } + err := r.VerifySafeMerge(&other) + assert.ErrorContains(t, err, "multiple resources named bar (registered_model at bar.yml, registered_model at bar2.yml)") +} diff --git a/bundle/deploy/terraform/convert.go b/bundle/deploy/terraform/convert.go index 7d95e719..3bfc8b83 100644 --- a/bundle/deploy/terraform/convert.go +++ b/bundle/deploy/terraform/convert.go @@ -44,6 +44,22 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon return dst } +func convGrants(acl []resources.Grant) *schema.ResourceGrants { + if len(acl) == 0 { + return nil + } + + resource := schema.ResourceGrants{} + for _, ac := range acl { + resource.Grant = append(resource.Grant, schema.ResourceGrantsGrant{ + Privileges: ac.Privileges, + Principal: ac.Principal, + }) + } + + return &resource +} + // BundleToTerraform converts resources in a bundle configuration // to the equivalent Terraform JSON representation. // @@ -174,6 +190,19 @@ func BundleToTerraform(config *config.Root) *schema.Root { } } + for k, src := range config.Resources.RegisteredModels { + noResources = false + var dst schema.ResourceRegisteredModel + conv(src, &dst) + tfroot.Resource.RegisteredModel[k] = &dst + + // Configure permissions for this resource. + if rp := convGrants(src.Grants); rp != nil { + rp.Function = fmt.Sprintf("${databricks_registered_model.%s.id}", k) + tfroot.Resource.Grants["registered_model_"+k] = rp + } + } + // We explicitly set "resource" to nil to omit it from a JSON encoding. // This is required because the terraform CLI requires >= 1 resources defined // if the "resource" property is used in a .tf.json file. @@ -221,7 +250,14 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error { cur := config.Resources.ModelServingEndpoints[resource.Name] conv(tmp, &cur) config.Resources.ModelServingEndpoints[resource.Name] = cur + case "databricks_registered_model": + var tmp schema.ResourceRegisteredModel + conv(resource.AttributeValues, &tmp) + cur := config.Resources.RegisteredModels[resource.Name] + conv(tmp, &cur) + config.Resources.RegisteredModels[resource.Name] = cur case "databricks_permissions": + case "databricks_grants": // Ignore; no need to pull these back into the configuration. default: return fmt.Errorf("missing mapping for %s", resource.Type) diff --git a/bundle/deploy/terraform/convert_test.go b/bundle/deploy/terraform/convert_test.go index b6b29f35..bb5a63ec 100644 --- a/bundle/deploy/terraform/convert_test.go +++ b/bundle/deploy/terraform/convert_test.go @@ -5,6 +5,7 @@ import ( "github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" @@ -366,3 +367,58 @@ func TestConvertModelServingPermissions(t *testing.T) { assert.Equal(t, "CAN_VIEW", p.PermissionLevel) } + +func TestConvertRegisteredModel(t *testing.T) { + var src = resources.RegisteredModel{ + CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{ + Name: "name", + CatalogName: "catalog", + SchemaName: "schema", + Comment: "comment", + }, + } + + var config = config.Root{ + Resources: config.Resources{ + RegisteredModels: map[string]*resources.RegisteredModel{ + "my_registered_model": &src, + }, + }, + } + + out := BundleToTerraform(&config) + resource := out.Resource.RegisteredModel["my_registered_model"] + assert.Equal(t, "name", resource.Name) + assert.Equal(t, "catalog", resource.CatalogName) + assert.Equal(t, "schema", resource.SchemaName) + assert.Equal(t, "comment", resource.Comment) + assert.Nil(t, out.Data) +} + +func TestConvertRegisteredModelGrants(t *testing.T) { + var src = resources.RegisteredModel{ + Grants: []resources.Grant{ + { + Privileges: []string{"EXECUTE"}, + Principal: "jane@doe.com", + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + RegisteredModels: map[string]*resources.RegisteredModel{ + "my_registered_model": &src, + }, + }, + } + + out := BundleToTerraform(&config) + assert.NotEmpty(t, out.Resource.Grants["registered_model_my_registered_model"].Function) + assert.Len(t, out.Resource.Grants["registered_model_my_registered_model"].Grant, 1) + + p := out.Resource.Grants["registered_model_my_registered_model"].Grant[0] + assert.Equal(t, "jane@doe.com", p.Principal) + assert.Equal(t, "EXECUTE", p.Privileges[0]) + +} diff --git a/bundle/deploy/terraform/interpolate.go b/bundle/deploy/terraform/interpolate.go index ea3c99aa..4f00c27e 100644 --- a/bundle/deploy/terraform/interpolate.go +++ b/bundle/deploy/terraform/interpolate.go @@ -28,6 +28,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri case "model_serving_endpoints": path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter) return fmt.Sprintf("${%s}", path), nil + case "registered_models": + path = strings.Join(append([]string{"databricks_registered_model"}, parts[2:]...), interpolation.Delimiter) + return fmt.Sprintf("${%s}", path), nil default: panic("TODO: " + parts[1]) } diff --git a/bundle/schema/openapi.go b/bundle/schema/openapi.go index 1a8b76ed..0b64c43e 100644 --- a/bundle/schema/openapi.go +++ b/bundle/schema/openapi.go @@ -223,6 +223,19 @@ func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) { return modelServingEndpointsAllDocs, nil } +func (reader *OpenapiReader) registeredModelDocs() (*Docs, error) { + registeredModelsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "catalog.CreateRegisteredModelRequest") + if err != nil { + return nil, err + } + registeredModelsDocs := schemaToDocs(registeredModelsSpecSchema) + registeredModelsAllDocs := &Docs{ + Description: "List of Registered Models", + AdditionalProperties: registeredModelsDocs, + } + return registeredModelsAllDocs, nil +} + func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { jobsDocs, err := reader.jobsDocs() if err != nil { @@ -244,6 +257,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { if err != nil { return nil, err } + registeredModelsDocs, err := reader.registeredModelDocs() + if err != nil { + return nil, err + } return &Docs{ Description: "Collection of Databricks resources to deploy.", @@ -253,6 +270,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { "experiments": experimentsDocs, "models": modelsDocs, "model_serving_endpoints": modelServingEndpointsDocs, + "registered_models": registeredModelsDocs, }, }, nil } diff --git a/bundle/tests/registered_model/databricks.yml b/bundle/tests/registered_model/databricks.yml new file mode 100644 index 00000000..b7b8ea5d --- /dev/null +++ b/bundle/tests/registered_model/databricks.yml @@ -0,0 +1,32 @@ +resources: + registered_models: + my_registered_model: + name: "my-model" + comment: "comment" + catalog_name: "main" + schema_name: "default" + grants: + - privileges: + - EXECUTE + principal: "account users" + +targets: + development: + mode: development + resources: + registered_models: + my_registered_model: + name: "my-dev-model" + + staging: + resources: + registered_models: + my_registered_model: + name: "my-staging-model" + + production: + mode: production + resources: + registered_models: + my_registered_model: + name: "my-prod-model" diff --git a/bundle/tests/registered_model_test.go b/bundle/tests/registered_model_test.go new file mode 100644 index 00000000..920a2ac7 --- /dev/null +++ b/bundle/tests/registered_model_test.go @@ -0,0 +1,47 @@ +package config_tests + +import ( + "path/filepath" + "testing" + + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/stretchr/testify/assert" +) + +func assertExpectedModel(t *testing.T, p *resources.RegisteredModel) { + assert.Equal(t, "registered_model/databricks.yml", filepath.ToSlash(p.ConfigFilePath)) + assert.Equal(t, "main", p.CatalogName) + assert.Equal(t, "default", p.SchemaName) + assert.Equal(t, "comment", p.Comment) + assert.Equal(t, "account users", p.Grants[0].Principal) + assert.Equal(t, "EXECUTE", p.Grants[0].Privileges[0]) +} + +func TestRegisteredModelDevelopment(t *testing.T) { + b := loadTarget(t, "./registered_model", "development") + assert.Len(t, b.Config.Resources.RegisteredModels, 1) + assert.Equal(t, b.Config.Bundle.Mode, config.Development) + + p := b.Config.Resources.RegisteredModels["my_registered_model"] + assert.Equal(t, "my-dev-model", p.Name) + assertExpectedModel(t, p) +} + +func TestRegisteredModelStaging(t *testing.T) { + b := loadTarget(t, "./registered_model", "staging") + assert.Len(t, b.Config.Resources.RegisteredModels, 1) + + p := b.Config.Resources.RegisteredModels["my_registered_model"] + assert.Equal(t, "my-staging-model", p.Name) + assertExpectedModel(t, p) +} + +func TestRegisteredModelProduction(t *testing.T) { + b := loadTarget(t, "./registered_model", "production") + assert.Len(t, b.Config.Resources.RegisteredModels, 1) + + p := b.Config.Resources.RegisteredModels["my_registered_model"] + assert.Equal(t, "my-prod-model", p.Name) + assertExpectedModel(t, p) +}