mirror of https://github.com/databricks/cli.git
Support Unity Catalog Registered Models in bundles (#846)
## Changes <!-- Summary of your changes that are easy to understand --> Add UC Registered Models support to Databricks Asset Bundles as a new resource `registered_model`. Also add UC Permission support via a new resource `grant`. ## Tests <!-- How is this tested? --> Tested via unit tests and manual testing with an [example PR](https://github.com/databricks/bundle-examples-internal/pull/80) and a [custom Terraform provider](https://github.com/databricks/terraform-provider-databricks/pull/2771). <img width="698" alt="Screenshot 2023-10-08 at 4 57 23 PM" src="https://github.com/databricks/cli/assets/87999496/bcf605a9-7894-443b-865a-f7e240037815"> <img width="1109" alt="Screenshot 2023-10-08 at 4 56 47 PM" src="https://github.com/databricks/cli/assets/87999496/e4d6e424-cd70-4809-8843-6939ed2e172f"> <img width="1091" alt="Screenshot 2023-10-08 at 4 56 57 PM" src="https://github.com/databricks/cli/assets/87999496/88ebaabb-67db-4a11-88a5-df087e2e41c0"> --------- Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com> Co-authored-by: Andrew Nester <andrew.nester.dev@gmail.com> Co-authored-by: Pieter Noordhuis <pieter.noordhuis@databricks.com>
This commit is contained in:
parent
61cf4fbe8d
commit
24cc67563e
3
Makefile
3
Makefile
|
@ -30,4 +30,5 @@ vendor:
|
|||
@echo "✓ Filling vendor folder with library code ..."
|
||||
@go mod vendor
|
||||
|
||||
.PHONY: build vendor coverage test lint fmt
|
||||
.PHONY: build vendor coverage test lint fmt
|
||||
|
||||
|
|
|
@ -87,6 +87,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
|
|||
// (model serving doesn't yet support tags)
|
||||
}
|
||||
|
||||
for i := range r.RegisteredModels {
|
||||
prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
|
||||
r.RegisteredModels[i].Name = prefix + r.RegisteredModels[i].Name
|
||||
// (registered models in Unity Catalog don't yet support tags)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/databricks/cli/libs/tags"
|
||||
sdkconfig "github.com/databricks/databricks-sdk-go/config"
|
||||
"github.com/databricks/databricks-sdk-go/service/catalog"
|
||||
"github.com/databricks/databricks-sdk-go/service/iam"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/databricks/databricks-sdk-go/service/ml"
|
||||
|
@ -59,6 +60,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
|
|||
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
|
||||
"servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
|
||||
},
|
||||
RegisteredModels: map[string]*resources.RegisteredModel{
|
||||
"registeredmodel1": {CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{Name: "registeredmodel1"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Use AWS implementation for testing.
|
||||
|
@ -86,6 +90,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
|
|||
// Experiment 1
|
||||
assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
|
||||
assert.Contains(t, bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags, ml.ExperimentTag{Key: "dev", Value: "lennart"})
|
||||
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
|
||||
|
||||
// Experiment 2
|
||||
assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
|
||||
|
@ -96,7 +101,9 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
|
|||
|
||||
// Model serving endpoint 1
|
||||
assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
|
||||
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
|
||||
|
||||
// Registered model 1
|
||||
assert.Equal(t, "dev_lennart_registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
|
||||
}
|
||||
|
||||
func TestProcessTargetModeDevelopmentTagNormalizationForAws(t *testing.T) {
|
||||
|
@ -151,6 +158,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
|
|||
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
|
||||
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
|
||||
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
|
||||
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
|
||||
}
|
||||
|
||||
func TestProcessTargetModeProduction(t *testing.T) {
|
||||
|
@ -187,6 +195,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
|
|||
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
|
||||
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
|
||||
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
|
||||
assert.Equal(t, "registeredmodel1", bundle.Config.Resources.RegisteredModels["registeredmodel1"].Name)
|
||||
}
|
||||
|
||||
func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {
|
||||
|
|
|
@ -14,6 +14,7 @@ type Resources struct {
|
|||
Models map[string]*resources.MlflowModel `json:"models,omitempty"`
|
||||
Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
|
||||
ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
|
||||
RegisteredModels map[string]*resources.RegisteredModel `json:"registered_models,omitempty"`
|
||||
}
|
||||
|
||||
type UniqueResourceIdTracker struct {
|
||||
|
@ -107,6 +108,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
|
|||
tracker.Type[k] = "model_serving_endpoint"
|
||||
tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
|
||||
}
|
||||
for k := range r.RegisteredModels {
|
||||
if _, ok := tracker.Type[k]; ok {
|
||||
return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
|
||||
k,
|
||||
tracker.Type[k],
|
||||
tracker.ConfigPath[k],
|
||||
"registered_model",
|
||||
r.RegisteredModels[k].ConfigFilePath,
|
||||
)
|
||||
}
|
||||
tracker.Type[k] = "registered_model"
|
||||
tracker.ConfigPath[k] = r.RegisteredModels[k].ConfigFilePath
|
||||
}
|
||||
return tracker, nil
|
||||
}
|
||||
|
||||
|
@ -129,6 +143,9 @@ func (r *Resources) SetConfigFilePath(path string) {
|
|||
for _, e := range r.ModelServingEndpoints {
|
||||
e.ConfigFilePath = path
|
||||
}
|
||||
for _, e := range r.RegisteredModels {
|
||||
e.ConfigFilePath = path
|
||||
}
|
||||
}
|
||||
|
||||
// Merge iterates over all resources and merges chunks of the
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
package resources
|
||||
|
||||
// Grant describes the Unity Catalog privileges given to a single principal.
// Any Unity Catalog resource may carry several of these.
type Grant struct {
	// Privileges lists the privilege names granted to the principal.
	Privileges []string `json:"privileges"`

	// Principal identifies who receives the privileges.
	Principal string `json:"principal"`
}
|
|
@ -15,8 +15,8 @@ type ModelServingEndpoint struct {
|
|||
// as a reference in other resources. This value is returned by terraform.
|
||||
ID string
|
||||
|
||||
// Local path where the bundle is defined. All bundle resources include
|
||||
// this for interpolation purposes.
|
||||
// Path to config file where the resource is defined. All bundle resources
|
||||
// include this for interpolation purposes.
|
||||
paths.Paths
|
||||
|
||||
// This is a resource agnostic implementation of permissions for ACLs.
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package resources
|
||||
|
||||
import (
|
||||
"github.com/databricks/cli/bundle/config/paths"
|
||||
"github.com/databricks/databricks-sdk-go/marshal"
|
||||
"github.com/databricks/databricks-sdk-go/service/catalog"
|
||||
)
|
||||
|
||||
type RegisteredModel struct {
|
||||
// This is a resource agnostic implementation of grants.
|
||||
// Implementation could be different based on the resource type.
|
||||
Grants []Grant `json:"grants,omitempty"`
|
||||
|
||||
// This represents the id which is the full name of the model
|
||||
// (catalog_name.schema_name.model_name) that can be used
|
||||
// as a reference in other resources. This value is returned by terraform.
|
||||
ID string
|
||||
|
||||
// Path to config file where the resource is defined. All bundle resources
|
||||
// include this for interpolation purposes.
|
||||
paths.Paths
|
||||
|
||||
// This represents the input args for terraform, and will get converted
|
||||
// to a HCL representation for CRUD
|
||||
*catalog.CreateRegisteredModelRequest
|
||||
}
|
||||
|
||||
func (s *RegisteredModel) UnmarshalJSON(b []byte) error {
|
||||
return marshal.Unmarshal(b, s)
|
||||
}
|
||||
|
||||
func (s RegisteredModel) MarshalJSON() ([]byte, error) {
|
||||
return marshal.Marshal(s)
|
||||
}
|
|
@ -95,3 +95,33 @@ func TestVerifySafeMergeForSameResourceType(t *testing.T) {
|
|||
err := r.VerifySafeMerge(&other)
|
||||
assert.ErrorContains(t, err, "multiple resources named foo (job at foo.yml, job at foo2.yml)")
|
||||
}
|
||||
|
||||
func TestVerifySafeMergeForRegisteredModels(t *testing.T) {
|
||||
r := Resources{
|
||||
Jobs: map[string]*resources.Job{
|
||||
"foo": {
|
||||
Paths: paths.Paths{
|
||||
ConfigFilePath: "foo.yml",
|
||||
},
|
||||
},
|
||||
},
|
||||
RegisteredModels: map[string]*resources.RegisteredModel{
|
||||
"bar": {
|
||||
Paths: paths.Paths{
|
||||
ConfigFilePath: "bar.yml",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
other := Resources{
|
||||
RegisteredModels: map[string]*resources.RegisteredModel{
|
||||
"bar": {
|
||||
Paths: paths.Paths{
|
||||
ConfigFilePath: "bar2.yml",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := r.VerifySafeMerge(&other)
|
||||
assert.ErrorContains(t, err, "multiple resources named bar (registered_model at bar.yml, registered_model at bar2.yml)")
|
||||
}
|
||||
|
|
|
@ -44,6 +44,22 @@ func convPermission(ac resources.Permission) schema.ResourcePermissionsAccessCon
|
|||
return dst
|
||||
}
|
||||
|
||||
func convGrants(acl []resources.Grant) *schema.ResourceGrants {
|
||||
if len(acl) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
resource := schema.ResourceGrants{}
|
||||
for _, ac := range acl {
|
||||
resource.Grant = append(resource.Grant, schema.ResourceGrantsGrant{
|
||||
Privileges: ac.Privileges,
|
||||
Principal: ac.Principal,
|
||||
})
|
||||
}
|
||||
|
||||
return &resource
|
||||
}
|
||||
|
||||
// BundleToTerraform converts resources in a bundle configuration
|
||||
// to the equivalent Terraform JSON representation.
|
||||
//
|
||||
|
@ -174,6 +190,19 @@ func BundleToTerraform(config *config.Root) *schema.Root {
|
|||
}
|
||||
}
|
||||
|
||||
for k, src := range config.Resources.RegisteredModels {
|
||||
noResources = false
|
||||
var dst schema.ResourceRegisteredModel
|
||||
conv(src, &dst)
|
||||
tfroot.Resource.RegisteredModel[k] = &dst
|
||||
|
||||
// Configure permissions for this resource.
|
||||
if rp := convGrants(src.Grants); rp != nil {
|
||||
rp.Function = fmt.Sprintf("${databricks_registered_model.%s.id}", k)
|
||||
tfroot.Resource.Grants["registered_model_"+k] = rp
|
||||
}
|
||||
}
|
||||
|
||||
// We explicitly set "resource" to nil to omit it from a JSON encoding.
|
||||
// This is required because the terraform CLI requires >= 1 resources defined
|
||||
// if the "resource" property is used in a .tf.json file.
|
||||
|
@ -221,7 +250,14 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
|
|||
cur := config.Resources.ModelServingEndpoints[resource.Name]
|
||||
conv(tmp, &cur)
|
||||
config.Resources.ModelServingEndpoints[resource.Name] = cur
|
||||
case "databricks_registered_model":
|
||||
var tmp schema.ResourceRegisteredModel
|
||||
conv(resource.AttributeValues, &tmp)
|
||||
cur := config.Resources.RegisteredModels[resource.Name]
|
||||
conv(tmp, &cur)
|
||||
config.Resources.RegisteredModels[resource.Name] = cur
|
||||
case "databricks_permissions":
|
||||
case "databricks_grants":
|
||||
// Ignore; no need to pull these back into the configuration.
|
||||
default:
|
||||
return fmt.Errorf("missing mapping for %s", resource.Type)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/databricks/cli/bundle/config"
|
||||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/databricks/databricks-sdk-go/service/catalog"
|
||||
"github.com/databricks/databricks-sdk-go/service/compute"
|
||||
"github.com/databricks/databricks-sdk-go/service/jobs"
|
||||
"github.com/databricks/databricks-sdk-go/service/ml"
|
||||
|
@ -366,3 +367,58 @@ func TestConvertModelServingPermissions(t *testing.T) {
|
|||
assert.Equal(t, "CAN_VIEW", p.PermissionLevel)
|
||||
|
||||
}
|
||||
|
||||
func TestConvertRegisteredModel(t *testing.T) {
|
||||
var src = resources.RegisteredModel{
|
||||
CreateRegisteredModelRequest: &catalog.CreateRegisteredModelRequest{
|
||||
Name: "name",
|
||||
CatalogName: "catalog",
|
||||
SchemaName: "schema",
|
||||
Comment: "comment",
|
||||
},
|
||||
}
|
||||
|
||||
var config = config.Root{
|
||||
Resources: config.Resources{
|
||||
RegisteredModels: map[string]*resources.RegisteredModel{
|
||||
"my_registered_model": &src,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out := BundleToTerraform(&config)
|
||||
resource := out.Resource.RegisteredModel["my_registered_model"]
|
||||
assert.Equal(t, "name", resource.Name)
|
||||
assert.Equal(t, "catalog", resource.CatalogName)
|
||||
assert.Equal(t, "schema", resource.SchemaName)
|
||||
assert.Equal(t, "comment", resource.Comment)
|
||||
assert.Nil(t, out.Data)
|
||||
}
|
||||
|
||||
func TestConvertRegisteredModelGrants(t *testing.T) {
|
||||
var src = resources.RegisteredModel{
|
||||
Grants: []resources.Grant{
|
||||
{
|
||||
Privileges: []string{"EXECUTE"},
|
||||
Principal: "jane@doe.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var config = config.Root{
|
||||
Resources: config.Resources{
|
||||
RegisteredModels: map[string]*resources.RegisteredModel{
|
||||
"my_registered_model": &src,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out := BundleToTerraform(&config)
|
||||
assert.NotEmpty(t, out.Resource.Grants["registered_model_my_registered_model"].Function)
|
||||
assert.Len(t, out.Resource.Grants["registered_model_my_registered_model"].Grant, 1)
|
||||
|
||||
p := out.Resource.Grants["registered_model_my_registered_model"].Grant[0]
|
||||
assert.Equal(t, "jane@doe.com", p.Principal)
|
||||
assert.Equal(t, "EXECUTE", p.Privileges[0])
|
||||
|
||||
}
|
||||
|
|
|
@ -28,6 +28,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
|
|||
case "model_serving_endpoints":
|
||||
path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
|
||||
return fmt.Sprintf("${%s}", path), nil
|
||||
case "registered_models":
|
||||
path = strings.Join(append([]string{"databricks_registered_model"}, parts[2:]...), interpolation.Delimiter)
|
||||
return fmt.Sprintf("${%s}", path), nil
|
||||
default:
|
||||
panic("TODO: " + parts[1])
|
||||
}
|
||||
|
|
|
@ -223,6 +223,19 @@ func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
|
|||
return modelServingEndpointsAllDocs, nil
|
||||
}
|
||||
|
||||
func (reader *OpenapiReader) registeredModelDocs() (*Docs, error) {
|
||||
registeredModelsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "catalog.CreateRegisteredModelRequest")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registeredModelsDocs := schemaToDocs(registeredModelsSpecSchema)
|
||||
registeredModelsAllDocs := &Docs{
|
||||
Description: "List of Registered Models",
|
||||
AdditionalProperties: registeredModelsDocs,
|
||||
}
|
||||
return registeredModelsAllDocs, nil
|
||||
}
|
||||
|
||||
func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
|
||||
jobsDocs, err := reader.jobsDocs()
|
||||
if err != nil {
|
||||
|
@ -244,6 +257,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registeredModelsDocs, err := reader.registeredModelDocs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Docs{
|
||||
Description: "Collection of Databricks resources to deploy.",
|
||||
|
@ -253,6 +270,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
|
|||
"experiments": experimentsDocs,
|
||||
"models": modelsDocs,
|
||||
"model_serving_endpoints": modelServingEndpointsDocs,
|
||||
"registered_models": registeredModelsDocs,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
resources:
|
||||
registered_models:
|
||||
my_registered_model:
|
||||
name: "my-model"
|
||||
comment: "comment"
|
||||
catalog_name: "main"
|
||||
schema_name: "default"
|
||||
grants:
|
||||
- privileges:
|
||||
- EXECUTE
|
||||
principal: "account users"
|
||||
|
||||
targets:
|
||||
development:
|
||||
mode: development
|
||||
resources:
|
||||
registered_models:
|
||||
my_registered_model:
|
||||
name: "my-dev-model"
|
||||
|
||||
staging:
|
||||
resources:
|
||||
registered_models:
|
||||
my_registered_model:
|
||||
name: "my-staging-model"
|
||||
|
||||
production:
|
||||
mode: production
|
||||
resources:
|
||||
registered_models:
|
||||
my_registered_model:
|
||||
name: "my-prod-model"
|
|
@ -0,0 +1,47 @@
|
|||
package config_tests
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/cli/bundle/config"
|
||||
"github.com/databricks/cli/bundle/config/resources"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func assertExpectedModel(t *testing.T, p *resources.RegisteredModel) {
|
||||
assert.Equal(t, "registered_model/databricks.yml", filepath.ToSlash(p.ConfigFilePath))
|
||||
assert.Equal(t, "main", p.CatalogName)
|
||||
assert.Equal(t, "default", p.SchemaName)
|
||||
assert.Equal(t, "comment", p.Comment)
|
||||
assert.Equal(t, "account users", p.Grants[0].Principal)
|
||||
assert.Equal(t, "EXECUTE", p.Grants[0].Privileges[0])
|
||||
}
|
||||
|
||||
func TestRegisteredModelDevelopment(t *testing.T) {
|
||||
b := loadTarget(t, "./registered_model", "development")
|
||||
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
|
||||
assert.Equal(t, b.Config.Bundle.Mode, config.Development)
|
||||
|
||||
p := b.Config.Resources.RegisteredModels["my_registered_model"]
|
||||
assert.Equal(t, "my-dev-model", p.Name)
|
||||
assertExpectedModel(t, p)
|
||||
}
|
||||
|
||||
func TestRegisteredModelStaging(t *testing.T) {
|
||||
b := loadTarget(t, "./registered_model", "staging")
|
||||
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
|
||||
|
||||
p := b.Config.Resources.RegisteredModels["my_registered_model"]
|
||||
assert.Equal(t, "my-staging-model", p.Name)
|
||||
assertExpectedModel(t, p)
|
||||
}
|
||||
|
||||
func TestRegisteredModelProduction(t *testing.T) {
|
||||
b := loadTarget(t, "./registered_model", "production")
|
||||
assert.Len(t, b.Config.Resources.RegisteredModels, 1)
|
||||
|
||||
p := b.Config.Resources.RegisteredModels["my_registered_model"]
|
||||
assert.Equal(t, "my-prod-model", p.Name)
|
||||
assertExpectedModel(t, p)
|
||||
}
|
Loading…
Reference in New Issue