From 50eaf16307bae42a08edb506a9b9430de3eb0f1b Mon Sep 17 00:00:00 2001
From: Arpit Jasapara <87999496+arpitjasa-db@users.noreply.github.com>
Date: Thu, 7 Sep 2023 14:54:31 -0700
Subject: [PATCH] Support Model Serving Endpoints in bundles (#682)
## Changes
Add support for Model Serving Endpoints as a new `model_serving_endpoints` resource type in Databricks bundles: development-mode name prefixing, duplicate-identifier validation, conversion to the Terraform `databricks_model_serving` resource (including permissions), interpolation of resource references, and JSON schema documentation.

## Tests
Unit tests, plus manual testing via
https://github.com/databricks/bundle-examples-internal/pull/76
Signed-off-by: Arpit Jasapara
---
bundle/config/mutator/process_target_mode.go | 6 ++
.../mutator/process_target_mode_test.go | 8 ++
bundle/config/resources.go | 21 ++++-
.../resources/model_serving_endpoint.go | 24 ++++++
bundle/deploy/terraform/convert.go | 19 +++++
bundle/deploy/terraform/convert_test.go | 74 +++++++++++++++++
bundle/deploy/terraform/interpolate.go | 3 +
bundle/schema/docs/bundle_descriptions.json | 81 +++++++++++++++++++
bundle/schema/openapi.go | 26 +++++-
.../model_serving_endpoint/databricks.yml | 38 +++++++++
bundle/tests/model_serving_endpoint_test.go | 48 +++++++++++
11 files changed, 342 insertions(+), 6 deletions(-)
create mode 100644 bundle/config/resources/model_serving_endpoint.go
create mode 100644 bundle/tests/model_serving_endpoint/databricks.yml
create mode 100644 bundle/tests/model_serving_endpoint_test.go
diff --git a/bundle/config/mutator/process_target_mode.go b/bundle/config/mutator/process_target_mode.go
index 06ae7b85..93149ad0 100644
--- a/bundle/config/mutator/process_target_mode.go
+++ b/bundle/config/mutator/process_target_mode.go
@@ -77,6 +77,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName})
}
+ for i := range r.ModelServingEndpoints {
+ prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
+ r.ModelServingEndpoints[i].Name = prefix + r.ModelServingEndpoints[i].Name
+ // (model serving doesn't yet support tags)
+ }
+
return nil
}
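Note that unlike jobs, pipelines, and experiments, which get a bracketed `[dev <user>]` prefix in development mode, serving endpoints are renamed with a `dev_<short name>_` prefix: endpoint names only allow alphanumeric characters, dashes, and underscores, so the bracketed form cannot be used. A minimal, self-contained sketch of the expected rename (the `devEndpointName` helper is illustrative and not part of this patch):

```go
// Illustrative sketch of the development-mode rename applied above.
// Endpoint names cannot contain spaces or brackets, hence the
// underscore-based "dev_<short name>_" prefix.
package main

import "fmt"

// devEndpointName is a hypothetical helper, shown only to make the
// transformation concrete.
func devEndpointName(shortName, endpointName string) string {
	return "dev_" + shortName + "_" + endpointName
}

func main() {
	// Matches the expectation asserted in process_target_mode_test.go below.
	fmt.Println(devEndpointName("lennart", "servingendpoint1"))
	// Output: dev_lennart_servingendpoint1
}
```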
diff --git a/bundle/config/mutator/process_target_mode_test.go b/bundle/config/mutator/process_target_mode_test.go
index 489632e1..4ea33c70 100644
--- a/bundle/config/mutator/process_target_mode_test.go
+++ b/bundle/config/mutator/process_target_mode_test.go
@@ -13,6 +13,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
"github.com/databricks/databricks-sdk-go/service/pipelines"
+ "github.com/databricks/databricks-sdk-go/service/serving"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -53,6 +54,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
Models: map[string]*resources.MlflowModel{
"model1": {Model: &ml.Model{Name: "model1"}},
},
+ ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+ "servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
+ },
},
},
}
@@ -69,6 +73,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
assert.Equal(t, "[dev lennart] model1", bundle.Config.Resources.Models["model1"].Name)
+ assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
}
@@ -82,6 +87,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
+ assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
}
func TestProcessTargetModeProduction(t *testing.T) {
@@ -109,6 +115,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions
bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions
bundle.Config.Resources.Models["model1"].Permissions = permissions
+ bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions
err = validateProductionMode(context.Background(), bundle, false)
require.NoError(t, err)
@@ -116,6 +123,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
+ assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
}
func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {
diff --git a/bundle/config/resources.go b/bundle/config/resources.go
index 5d47b918..c239b510 100644
--- a/bundle/config/resources.go
+++ b/bundle/config/resources.go
@@ -11,8 +11,9 @@ type Resources struct {
Jobs map[string]*resources.Job `json:"jobs,omitempty"`
Pipelines map[string]*resources.Pipeline `json:"pipelines,omitempty"`
- Models map[string]*resources.MlflowModel `json:"models,omitempty"`
- Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
+ Models map[string]*resources.MlflowModel `json:"models,omitempty"`
+ Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
+ ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
}
type UniqueResourceIdTracker struct {
@@ -93,6 +94,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
tracker.Type[k] = "mlflow_experiment"
tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath
}
+ for k := range r.ModelServingEndpoints {
+ if _, ok := tracker.Type[k]; ok {
+ return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
+ k,
+ tracker.Type[k],
+ tracker.ConfigPath[k],
+ "model_serving_endpoint",
+ r.ModelServingEndpoints[k].ConfigFilePath,
+ )
+ }
+ tracker.Type[k] = "model_serving_endpoint"
+ tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
+ }
return tracker, nil
}
@@ -112,6 +126,9 @@ func (r *Resources) SetConfigFilePath(path string) {
for _, e := range r.Experiments {
e.ConfigFilePath = path
}
+ for _, e := range r.ModelServingEndpoints {
+ e.ConfigFilePath = path
+ }
}
// MergeJobClusters iterates over all jobs and merges their job clusters.
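Resource keys live in a single namespace shared across all resource types, so a model serving endpoint cannot reuse a key already taken by a job, pipeline, model, or experiment. A simplified, self-contained sketch of the check (the real implementation also records the config file path of each resource):

```go
// Simplified sketch of the uniqueness check extended above to cover
// model serving endpoints; keys are tracked across all resource types.
package main

import "fmt"

func main() {
	typeByKey := map[string]string{"foo": "job"} // resources seen so far
	endpointKeys := []string{"foo", "bar"}       // model_serving_endpoints keys

	for _, k := range endpointKeys {
		if other, ok := typeByKey[k]; ok {
			fmt.Printf("multiple resources named %s (%s, model_serving_endpoint)\n", k, other)
			continue
		}
		typeByKey[k] = "model_serving_endpoint"
	}
}
```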
diff --git a/bundle/config/resources/model_serving_endpoint.go b/bundle/config/resources/model_serving_endpoint.go
new file mode 100644
index 00000000..dccecaa6
--- /dev/null
+++ b/bundle/config/resources/model_serving_endpoint.go
@@ -0,0 +1,24 @@
+package resources
+
+import (
+ "github.com/databricks/cli/bundle/config/paths"
+ "github.com/databricks/databricks-sdk-go/service/serving"
+)
+
+type ModelServingEndpoint struct {
+ // This represents the input args for terraform, and will get converted
+ // to an HCL representation for CRUD
+ *serving.CreateServingEndpoint
+
+ // This represents the id (i.e. serving_endpoint_id) that can be used
+ // as a reference in other resources. This value is returned by terraform.
+ ID string
+
+ // Local path where the bundle is defined. All bundle resources include
+ // this for interpolation purposes.
+ paths.Paths
+
+ // This is a resource-agnostic implementation of permissions for ACLs.
+ // Implementation could be different based on the resource type.
+ Permissions []Permission `json:"permissions,omitempty"`
+}
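Since `serving.CreateServingEndpoint` is embedded, the YAML under a `model_serving_endpoints` entry maps one-to-one onto the serving API's create payload, with permissions and paths added at the bundle level. A sketch of constructing the resource in Go, mirroring the `databricks.yml` fixture added later in this patch (values are illustrative):

```go
// Illustrative construction of the new resource type; field names follow
// the embedded serving.CreateServingEndpoint from the Databricks SDK.
package main

import (
	"fmt"

	"github.com/databricks/cli/bundle/config/resources"
	"github.com/databricks/databricks-sdk-go/service/serving"
)

func main() {
	endpoint := resources.ModelServingEndpoint{
		CreateServingEndpoint: &serving.CreateServingEndpoint{
			Name: "my-endpoint",
			Config: serving.EndpointCoreConfigInput{
				ServedModels: []serving.ServedModelInput{{
					ModelName:          "model-name",
					ModelVersion:       "1",
					WorkloadSize:       "Small",
					ScaleToZeroEnabled: true,
				}},
			},
		},
		Permissions: []resources.Permission{{Level: "CAN_QUERY", GroupName: "users"}},
	}
	fmt.Println(endpoint.Name)
}
```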
diff --git a/bundle/deploy/terraform/convert.go b/bundle/deploy/terraform/convert.go
index cd480c89..0956ea7b 100644
--- a/bundle/deploy/terraform/convert.go
+++ b/bundle/deploy/terraform/convert.go
@@ -161,6 +161,19 @@ func BundleToTerraform(config *config.Root) (*schema.Root, bool) {
}
}
+ for k, src := range config.Resources.ModelServingEndpoints {
+ noResources = false
+ var dst schema.ResourceModelServing
+ conv(src, &dst)
+ tfroot.Resource.ModelServing[k] = &dst
+
+ // Configure permissions for this resource.
+ if rp := convPermissions(src.Permissions); rp != nil {
+ rp.ServingEndpointId = fmt.Sprintf("${databricks_model_serving.%s.serving_endpoint_id}", k)
+ tfroot.Resource.Permissions["model_serving_"+k] = rp
+ }
+ }
+
return tfroot, noResources
}
@@ -196,6 +209,12 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
cur := config.Resources.Experiments[resource.Name]
conv(tmp, &cur)
config.Resources.Experiments[resource.Name] = cur
+ case "databricks_model_serving":
+ var tmp schema.ResourceModelServing
+ conv(resource.AttributeValues, &tmp)
+ cur := config.Resources.ModelServingEndpoints[resource.Name]
+ conv(tmp, &cur)
+ config.Resources.ModelServingEndpoints[resource.Name] = cur
case "databricks_permissions":
// Ignore; no need to pull these back into the configuration.
default:
diff --git a/bundle/deploy/terraform/convert_test.go b/bundle/deploy/terraform/convert_test.go
index 34a65d70..ad626606 100644
--- a/bundle/deploy/terraform/convert_test.go
+++ b/bundle/deploy/terraform/convert_test.go
@@ -9,6 +9,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml"
"github.com/databricks/databricks-sdk-go/service/pipelines"
+ "github.com/databricks/databricks-sdk-go/service/serving"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -292,3 +293,76 @@ func TestConvertExperimentPermissions(t *testing.T) {
assert.Equal(t, "CAN_READ", p.PermissionLevel)
}
+
+func TestConvertModelServing(t *testing.T) {
+ var src = resources.ModelServingEndpoint{
+ CreateServingEndpoint: &serving.CreateServingEndpoint{
+ Name: "name",
+ Config: serving.EndpointCoreConfigInput{
+ ServedModels: []serving.ServedModelInput{
+ {
+ ModelName: "model_name",
+ ModelVersion: "1",
+ ScaleToZeroEnabled: true,
+ WorkloadSize: "Small",
+ },
+ },
+ TrafficConfig: &serving.TrafficConfig{
+ Routes: []serving.Route{
+ {
+ ServedModelName: "model_name-1",
+ TrafficPercentage: 100,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ var config = config.Root{
+ Resources: config.Resources{
+ ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+ "my_model_serving_endpoint": &src,
+ },
+ },
+ }
+
+ out, _ := BundleToTerraform(&config)
+ resource := out.Resource.ModelServing["my_model_serving_endpoint"]
+ assert.Equal(t, "name", resource.Name)
+ assert.Equal(t, "model_name", resource.Config.ServedModels[0].ModelName)
+ assert.Equal(t, "1", resource.Config.ServedModels[0].ModelVersion)
+ assert.Equal(t, true, resource.Config.ServedModels[0].ScaleToZeroEnabled)
+ assert.Equal(t, "Small", resource.Config.ServedModels[0].WorkloadSize)
+ assert.Equal(t, "model_name-1", resource.Config.TrafficConfig.Routes[0].ServedModelName)
+ assert.Equal(t, 100, resource.Config.TrafficConfig.Routes[0].TrafficPercentage)
+ assert.Nil(t, out.Data)
+}
+
+func TestConvertModelServingPermissions(t *testing.T) {
+ var src = resources.ModelServingEndpoint{
+ Permissions: []resources.Permission{
+ {
+ Level: "CAN_VIEW",
+ UserName: "jane@doe.com",
+ },
+ },
+ }
+
+ var config = config.Root{
+ Resources: config.Resources{
+ ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
+ "my_model_serving_endpoint": &src,
+ },
+ },
+ }
+
+ out, _ := BundleToTerraform(&config)
+ assert.NotEmpty(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].ServingEndpointId)
+ assert.Len(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl, 1)
+
+ p := out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl[0]
+ assert.Equal(t, "jane@doe.com", p.UserName)
+ assert.Equal(t, "CAN_VIEW", p.PermissionLevel)
+
+}
diff --git a/bundle/deploy/terraform/interpolate.go b/bundle/deploy/terraform/interpolate.go
index dd1dcbb8..ea3c99aa 100644
--- a/bundle/deploy/terraform/interpolate.go
+++ b/bundle/deploy/terraform/interpolate.go
@@ -25,6 +25,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
case "experiments":
path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
+ case "model_serving_endpoints":
+ path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
+ return fmt.Sprintf("${%s}", path), nil
default:
panic("TODO: " + parts[1])
}
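With this case added, other resources can reference an endpoint as `${resources.model_serving_endpoints.<key>.<attribute>}`, which gets rewritten into the corresponding Terraform reference on `databricks_model_serving`. A minimal sketch of the rewrite, assuming `.` is the interpolation delimiter:

```go
// Minimal sketch (assuming "." as the interpolation delimiter) of the
// path rewrite added above for model serving endpoints.
package main

import (
	"fmt"
	"strings"
)

func rewrite(path string) string {
	parts := strings.Split(path, ".")
	if len(parts) > 2 && parts[0] == "resources" && parts[1] == "model_serving_endpoints" {
		path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), ".")
	}
	return fmt.Sprintf("${%s}", path)
}

func main() {
	// resources.model_serving_endpoints.my_model_serving_endpoint.id
	// -> ${databricks_model_serving.my_model_serving_endpoint.id}
	fmt.Println(rewrite("resources.model_serving_endpoints.my_model_serving_endpoint.id"))
}
```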
diff --git a/bundle/schema/docs/bundle_descriptions.json b/bundle/schema/docs/bundle_descriptions.json
index 84f0492f..ffdb5629 100644
--- a/bundle/schema/docs/bundle_descriptions.json
+++ b/bundle/schema/docs/bundle_descriptions.json
@@ -1441,6 +1441,87 @@
}
}
},
+ "model_serving_endpoints": {
+ "description": "List of Model Serving Endpoints",
+ "additionalproperties": {
+ "description": "",
+ "properties": {
+ "name": {
+ "description": "The name of the model serving endpoint. This field is required and must be unique across a workspace. An endpoint name can consist of alphanumeric characters, dashes, and underscores. NOTE: Changing this name will delete the existing endpoint and create a new endpoint with the updated name."
+ },
+ "permissions": {
+ "description": "",
+ "items": {
+ "description": "",
+ "properties": {
+ "group_name": {
+ "description": ""
+ },
+ "level": {
+ "description": ""
+ },
+ "service_principal_name": {
+ "description": ""
+ },
+ "user_name": {
+ "description": ""
+ }
+ }
+ }
+ },
+ "config": {
+ "description": "The model serving endpoint configuration.",
+ "properties": {
+ "description": "",
+ "properties": {
+ "served_models": {
+ "description": "Each block represents a served model for the endpoint to serve. A model serving endpoint can have up to 10 served models.",
+ "items": {
+ "description": "",
+ "properties": {
+ "name": {
+ "description": "The name of a served model. It must be unique across an endpoint. If not specified, this field will default to modelname-modelversion. A served model name can consist of alphanumeric characters, dashes, and underscores."
+ },
+ "model_name": {
+ "description": "The name of the model in Databricks Model Registry to be served."
+ },
+ "model_version": {
+ "description": "The version of the model in Databricks Model Registry to be served."
+ },
+ "workload_size": {
+ "description": "The workload size of the served model. The workload size corresponds to a range of provisioned concurrency that the compute will autoscale between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are \"Small\" (4 - 4 provisioned concurrency), \"Medium\" (8 - 16 provisioned concurrency), and \"Large\" (16 - 64 provisioned concurrency)."
+ },
+ "scale_to_zero_enabled": {
+ "description": "Whether the compute resources for the served model should scale down to zero. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size will be 0."
+ }
+ }
+ }
+ },
+ "traffic_config": {
+ "description": "A single block represents the traffic split configuration amongst the served models.",
+ "properties": {
+ "routes": {
+ "description": "Each block represents a route that defines traffic to each served model. Each served_models block needs to have a corresponding routes block.",
+ "items": {
+ "description": "",
+ "properties": {
+ "served_model_name": {
+ "description": "The name of the served model this route configures traffic for. This needs to match the name of a served_models block."
+ },
+ "traffic_percentage": {
+ "description": "The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive."
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"pipelines": {
"description": "List of DLT pipelines",
"additionalproperties": {
diff --git a/bundle/schema/openapi.go b/bundle/schema/openapi.go
index b0d67657..1a8b76ed 100644
--- a/bundle/schema/openapi.go
+++ b/bundle/schema/openapi.go
@@ -210,6 +210,19 @@ func (reader *OpenapiReader) modelsDocs() (*Docs, error) {
return modelsDocs, nil
}
+func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
+ modelServingEndpointsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "serving.CreateServingEndpoint")
+ if err != nil {
+ return nil, err
+ }
+ modelServingEndpointsDocs := schemaToDocs(modelServingEndpointsSpecSchema)
+ modelServingEndpointsAllDocs := &Docs{
+ Description: "List of Model Serving Endpoints",
+ AdditionalProperties: modelServingEndpointsDocs,
+ }
+ return modelServingEndpointsAllDocs, nil
+}
+
func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
jobsDocs, err := reader.jobsDocs()
if err != nil {
@@ -227,14 +240,19 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
if err != nil {
return nil, err
}
+ modelServingEndpointsDocs, err := reader.modelServingEndpointsDocs()
+ if err != nil {
+ return nil, err
+ }
return &Docs{
Description: "Collection of Databricks resources to deploy.",
Properties: map[string]*Docs{
- "jobs": jobsDocs,
- "pipelines": pipelinesDocs,
- "experiments": experimentsDocs,
- "models": modelsDocs,
+ "jobs": jobsDocs,
+ "pipelines": pipelinesDocs,
+ "experiments": experimentsDocs,
+ "models": modelsDocs,
+ "model_serving_endpoints": modelServingEndpointsDocs,
},
}, nil
}
diff --git a/bundle/tests/model_serving_endpoint/databricks.yml b/bundle/tests/model_serving_endpoint/databricks.yml
new file mode 100644
index 00000000..e4fb54a1
--- /dev/null
+++ b/bundle/tests/model_serving_endpoint/databricks.yml
@@ -0,0 +1,38 @@
+resources:
+ model_serving_endpoints:
+ my_model_serving_endpoint:
+ name: "my-endpoint"
+ config:
+ served_models:
+ - model_name: "model-name"
+ model_version: "1"
+ workload_size: "Small"
+ scale_to_zero_enabled: true
+ traffic_config:
+ routes:
+ - served_model_name: "model-name-1"
+ traffic_percentage: 100
+ permissions:
+ - level: CAN_QUERY
+ group_name: users
+
+targets:
+ development:
+ mode: development
+ resources:
+ model_serving_endpoints:
+ my_model_serving_endpoint:
+ name: "my-dev-endpoint"
+
+ staging:
+ resources:
+ model_serving_endpoints:
+ my_model_serving_endpoint:
+ name: "my-staging-endpoint"
+
+ production:
+ mode: production
+ resources:
+ model_serving_endpoints:
+ my_model_serving_endpoint:
+ name: "my-prod-endpoint"
diff --git a/bundle/tests/model_serving_endpoint_test.go b/bundle/tests/model_serving_endpoint_test.go
new file mode 100644
index 00000000..bfa1a31b
--- /dev/null
+++ b/bundle/tests/model_serving_endpoint_test.go
@@ -0,0 +1,48 @@
+package config_tests
+
+import (
+ "path/filepath"
+ "testing"
+
+ "github.com/databricks/cli/bundle/config"
+ "github.com/databricks/cli/bundle/config/resources"
+ "github.com/stretchr/testify/assert"
+)
+
+func assertExpected(t *testing.T, p *resources.ModelServingEndpoint) {
+ assert.Equal(t, "model_serving_endpoint/databricks.yml", filepath.ToSlash(p.ConfigFilePath))
+ assert.Equal(t, "model-name", p.Config.ServedModels[0].ModelName)
+ assert.Equal(t, "1", p.Config.ServedModels[0].ModelVersion)
+ assert.Equal(t, "model-name-1", p.Config.TrafficConfig.Routes[0].ServedModelName)
+ assert.Equal(t, 100, p.Config.TrafficConfig.Routes[0].TrafficPercentage)
+ assert.Equal(t, "users", p.Permissions[0].GroupName)
+ assert.Equal(t, "CAN_QUERY", p.Permissions[0].Level)
+}
+
+func TestModelServingEndpointDevelopment(t *testing.T) {
+ b := loadTarget(t, "./model_serving_endpoint", "development")
+ assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
+ assert.Equal(t, b.Config.Bundle.Mode, config.Development)
+
+ p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
+ assert.Equal(t, "my-dev-endpoint", p.Name)
+ assertExpected(t, p)
+}
+
+func TestModelServingEndpointStaging(t *testing.T) {
+ b := loadTarget(t, "./model_serving_endpoint", "staging")
+ assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
+
+ p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
+ assert.Equal(t, "my-staging-endpoint", p.Name)
+ assertExpected(t, p)
+}
+
+func TestModelServingEndpointProduction(t *testing.T) {
+ b := loadTarget(t, "./model_serving_endpoint", "production")
+ assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
+
+ p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
+ assert.Equal(t, "my-prod-endpoint", p.Name)
+ assertExpected(t, p)
+}