Support Unity Catalog Registered Models in bundles (#846)

## Changes
<!-- Summary of your changes that are easy to understand -->
Add UC Registered Models support to Databricks Asset Bundles as new
resource `registered_model`. Also added UC Permission support via new
resource `grant`.

## Tests
<!-- How is this tested? -->
Tested via unit tests and manual testing with [example
PR](https://github.com/databricks/bundle-examples-internal/pull/80) and
[custom Terraform
provider](https://github.com/databricks/terraform-provider-databricks/pull/2771).
<img width="698" alt="Screenshot 2023-10-08 at 4 57 23 PM"
src="https://github.com/databricks/cli/assets/87999496/bcf605a9-7894-443b-865a-f7e240037815">
<img width="1109" alt="Screenshot 2023-10-08 at 4 56 47 PM"
src="https://github.com/databricks/cli/assets/87999496/e4d6e424-cd70-4809-8843-6939ed2e172f">
<img width="1091" alt="Screenshot 2023-10-08 at 4 56 57 PM"
src="https://github.com/databricks/cli/assets/87999496/88ebaabb-67db-4a11-88a5-df087e2e41c0">

---------

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>
Co-authored-by: Andrew Nester <andrew.nester.dev@gmail.com>
Co-authored-by: Pieter Noordhuis <pieter.noordhuis@databricks.com>
This commit is contained in:
Arpit Jasapara 2023-10-16 08:32:49 -07:00 committed by GitHub
parent 61cf4fbe8d
commit 24cc67563e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 302 additions and 4 deletions

View File

@ -30,4 +30,5 @@ vendor:
@echo "✓ Filling vendor folder with library code ..."
@go mod vendor
.PHONY: build vendor coverage test lint fmt
.PHONY: build vendor coverage test lint fmt

View File

@ -87,6 +87,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
// (model serving doesn't yet support tags)
}
for i := range r.RegisteredModels {
prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
r.RegisteredModels[i].Name = prefix + r.RegisteredModels[i].Name
// (registered models in Unity Catalog don't yet support tags)
}
return nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/tags"
sdkconfig "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
@ -59,6 +60,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
"servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
},
RegisteredModels: map[string]*resources.RegisteredModel{
"registeredmodel1": {CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{Name: "registeredmodel1"}},
},
},
},
// Use AWS implementation for testing.
@ -86,6 +90,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
// Experiment 1
assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
assert.Contains(t, bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags, ml.ExperimentTag{Key: "dev", Value: "lennart"})
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
// Experiment 2
assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
@ -96,7 +101,9 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
// Model serving endpoint 1
assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
// Registered model 1
assert.Equal(t, "dev_lennart_registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}
func TestProcessTargetModeDevelopmentTagNormalizationForAws(t *testing.T) {
@ -151,6 +158,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}
func TestProcessTargetModeProduction(t *testing.T) {
@ -187,6 +195,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
}
func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {

View File

@ -14,6 +14,7 @@ type Resources struct {
Models map[string]*resources.MlflowModel `json:"models,omitempty"`
Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
RegisteredModels map[string]*resources.RegisteredModel `json:"registered_models,omitempty"`
}
type UniqueResourceIdTracker struct {
@ -107,6 +108,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
tracker.Type[k] = "model_serving_endpoint"
tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
}
for k := range r.RegisteredModels {
if _, ok := tracker.Type[k]; ok {
return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
k,
tracker.Type[k],
tracker.ConfigPath[k],
"registered_model",
r.RegisteredModels[k].ConfigFilePath,
)
}
tracker.Type[k] = "registered_model"
tracker.ConfigPath[k] = r.RegisteredModels[k].ConfigFilePath
}
return tracker, nil
}
@ -129,6 +143,9 @@ func (r *Resources) SetConfigFilePath(path string) {
for _, e := range r.ModelServingEndpoints {
e.ConfigFilePath = path
}
for _, e := range r.RegisteredModels {
e.ConfigFilePath = path
}
}
// Merge iterates over all resources and merges chunks of the

View File

@ -0,0 +1,9 @@
package resources
// Grant holds the grant level settings for a single principal in Unity Catalog.
// Multiple of these can be defined on any Unity Catalog resource.
type Grant struct {
Privileges []string `json:"privileges"`
Principal string `json:"principal"`
}

View File

@ -15,8 +15,8 @@ type ModelServingEndpoint struct {
// as a reference in other resources. This value is returned by terraform.
ID string
// Local path where the bundle is defined. All bundle resources include
// this for interpolation purposes.
// Path to config file where the resource is defined. All bundle resources
// include this for interpolation purposes.
paths.Paths
// This is a resource agnostic implementation of permissions for ACLs.

View File

@ -0,0 +1,34 @@
package resources
import (
"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/databricks-sdk-go/marshal"
"github.com/databricks/databricks-sdk-go/service/catalog"
)
type RegisteredModel struct {
// This is a resource agnostic implementation of grants.
// Implementation could be different based on the resource type.
Grants []Grant `json:"grants,omitempty"`
// This represents the id which is the full name of the model
// (catalog_name.schema_name.model_name) that can be used
// as a reference in other resources. This value is returned by terraform.
ID string
// Path to config file where the resource is defined. All bundle resources
// include this for interpolation purposes.
paths.Paths
// This represents the input args for terraform, and will get converted
// to a HCL representation for CRUD
*catalog.CreateRegisteredModelRequest
}
func (s *RegisteredModel) UnmarshalJSON(b []byte) error {
return marshal.Unmarshal(b, s)
}
func (s RegisteredModel) MarshalJSON() ([]byte, error) {
return marshal.Marshal(s)
}

View File

@ -95,3 +95,33 @@ func TestVerifySafeMergeForSameResourceType(t *testing.T) {
err := r.VerifySafeMerge(&other)
assert.ErrorContains(t, err, "multiple resources named foo (job at foo.yml, job at foo2.yml)")
}
func TestVerifySafeMergeForRegisteredModels(t *testing.T) {
r := Resources{
Jobs: map[string]*resources.Job{
"foo": {
Paths: paths.Paths{
ConfigFilePath: "foo.yml",
},
},
},
RegisteredModels: map[string]*resources.RegisteredModel{
"bar": {
Paths: paths.Paths{
ConfigFilePath: "bar.yml",
},
},
},
}
other := Resources{
RegisteredModels: map[string]*resources.RegisteredModel{
"bar": {
Paths: paths.Paths{
ConfigFilePath: "bar2.yml",
},
},
},
}
err := r.VerifySafeMerge(&other)
assert.ErrorContains(t, err, "multiple resources named bar (registered_model at bar.yml, registered_model at bar2.yml)")
}

View File

@ -44,6 +44,22 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon
return dst
}
func convGrants(acl []resources.Grant) *schema.ResourceGrants {
if len(acl) == 0 {
return nil
}
resource := schema.ResourceGrants{}
for _, ac := range acl {
resource.Grant = append(resource.Grant, schema.ResourceGrantsGrant{
Privileges: ac.Privileges,
Principal: ac.Principal,
})
}
return &resource
}
// BundleToTerraform converts resources in a bundle configuration
// to the equivalent Terraform JSON representation.
//
@ -174,6 +190,19 @@ func BundleToTerraform(config *config.Root) *schema.Root {
}
}
for k, src := range config.Resources.RegisteredModels {
noResources = false
var dst schema.ResourceRegisteredModel
conv(src, &dst)
tfroot.Resource.RegisteredModel[k] = &dst
// Configure permissions for this resource.
if rp := convGrants(src.Grants); rp != nil {
rp.Function = fmt.Sprintf("${databricks_registered_model.%s.id}", k)
tfroot.Resource.Grants["registered_model_"+k] = rp
}
}
// We explicitly set "resource" to nil to omit it from a JSON encoding.
// This is required because the terraform CLI requires >= 1 resources defined
// if the "resource" property is used in a .tf.json file.
@ -221,7 +250,14 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
cur := config.Resources.ModelServingEndpoints[resource.Name]
conv(tmp, &cur)
config.Resources.ModelServingEndpoints[resource.Name] = cur
case "databricks_registered_model":
var tmp schema.ResourceRegisteredModel
conv(resource.AttributeValues, &tmp)
cur := config.Resources.RegisteredModels[resource.Name]
conv(tmp, &cur)
config.Resources.RegisteredModels[resource.Name] = cur
case "databricks_permissions":
case "databricks_grants":
// Ignore; no need to pull these back into the configuration.
default:
return fmt.Errorf("missing mapping for %s", resource.Type)

View File

@ -5,6 +5,7 @@ import (
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
@ -366,3 +367,58 @@ func TestConvertModelServingPermissions(t *testing.T) {
assert.Equal(t, "CAN_VIEW", p.PermissionLevel)
}
func TestConvertRegisteredModel(t *testing.T) {
var src = resources.RegisteredModel{
CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{
Name: "name",
CatalogName: "catalog",
SchemaName: "schema",
Comment: "comment",
},
}
var config = config.Root{
Resources: config.Resources{
RegisteredModels: map[string]*resources.RegisteredModel{
"my_registered_model": &src,
},
},
}
out := BundleToTerraform(&config)
resource := out.Resource.RegisteredModel["my_registered_model"]
assert.Equal(t, "name", resource.Name)
assert.Equal(t, "catalog", resource.CatalogName)
assert.Equal(t, "schema", resource.SchemaName)
assert.Equal(t, "comment", resource.Comment)
assert.Nil(t, out.Data)
}
func TestConvertRegisteredModelGrants(t *testing.T) {
var src = resources.RegisteredModel{
Grants: []resources.Grant{
{
Privileges: []string{"EXECUTE"},
Principal: "jane@doe.com",
},
},
}
var config = config.Root{
Resources: config.Resources{
RegisteredModels: map[string]*resources.RegisteredModel{
"my_registered_model": &src,
},
},
}
out := BundleToTerraform(&config)
assert.NotEmpty(t, out.Resource.Grants["registered_model_my_registered_model"].Function)
assert.Len(t, out.Resource.Grants["registered_model_my_registered_model"].Grant, 1)
p := out.Resource.Grants["registered_model_my_registered_model"].Grant[0]
assert.Equal(t, "jane@doe.com", p.Principal)
assert.Equal(t, "EXECUTE", p.Privileges[0])
}

View File

@ -28,6 +28,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
case "model_serving_endpoints":
path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
case "registered_models":
path = strings.Join(append([]string{"databricks_registered_model"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
default:
panic("TODO: " + parts[1])
}

View File

@ -223,6 +223,19 @@ func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
return modelServingEndpointsAllDocs, nil
}
func (reader *OpenapiReader) registeredModelDocs() (*Docs, error) {
registeredModelsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "catalog.CreateRegisteredModelRequest")
if err != nil {
return nil, err
}
registeredModelsDocs := schemaToDocs(registeredModelsSpecSchema)
registeredModelsAllDocs := &Docs{
Description: "List of Registered Models",
AdditionalProperties: registeredModelsDocs,
}
return registeredModelsAllDocs, nil
}
func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
jobsDocs, err := reader.jobsDocs()
if err != nil {
@ -244,6 +257,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
if err != nil {
return nil, err
}
registeredModelsDocs, err := reader.registeredModelDocs()
if err != nil {
return nil, err
}
return &Docs{
Description: "Collection of Databricks resources to deploy.",
@ -253,6 +270,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
"experiments": experimentsDocs,
"models": modelsDocs,
"model_serving_endpoints": modelServingEndpointsDocs,
"registered_models": registeredModelsDocs,
},
}, nil
}

View File

@ -0,0 +1,32 @@
resources:
registered_models:
my_registered_model:
name: "my-model"
comment: "comment"
catalog_name: "main"
schema_name: "default"
grants:
- privileges:
- EXECUTE
principal: "account users"
targets:
development:
mode: development
resources:
registered_models:
my_registered_model:
name: "my-dev-model"
staging:
resources:
registered_models:
my_registered_model:
name: "my-staging-model"
production:
mode: production
resources:
registered_models:
my_registered_model:
name: "my-prod-model"

View File

@ -0,0 +1,47 @@
package config_tests
import (
"path/filepath"
"testing"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/stretchr/testify/assert"
)
func assertExpectedModel(t *testing.T, p *resources.RegisteredModel) {
assert.Equal(t, "registered_model/databricks.yml", filepath.ToSlash(p.ConfigFilePath))
assert.Equal(t, "main", p.CatalogName)
assert.Equal(t, "default", p.SchemaName)
assert.Equal(t, "comment", p.Comment)
assert.Equal(t, "account users", p.Grants[0].Principal)
assert.Equal(t, "EXECUTE", p.Grants[0].Privileges[0])
}
func TestRegisteredModelDevelopment(t *testing.T) {
b := loadTarget(t, "./registered_model", "development")
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
assert.Equal(t, b.Config.Bundle.Mode, config.Development)
p := b.Config.Resources.RegisteredModels["my_registered_model"]
assert.Equal(t, "my-dev-model", p.Name)
assertExpectedModel(t, p)
}
func TestRegisteredModelStaging(t *testing.T) {
b := loadTarget(t, "./registered_model", "staging")
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
p := b.Config.Resources.RegisteredModels["my_registered_model"]
assert.Equal(t, "my-staging-model", p.Name)
assertExpectedModel(t, p)
}
func TestRegisteredModelProduction(t *testing.T) {
b := loadTarget(t, "./registered_model", "production")
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
p := b.Config.Resources.RegisteredModels["my_registered_model"]
assert.Equal(t, "my-prod-model", p.Name)
assertExpectedModel(t, p)
}