Support Model Serving Endpoints in bundles (#682)

## Changes
<!-- Summary of your changes that are easy to understand -->
Add Model Serving Endpoints to Databricks Bundles

## Tests
<!-- How is this tested? -->
Unit tests and manual testing via
https://github.com/databricks/bundle-examples-internal/pull/76
<img width="1570" alt="Screenshot 2023-08-28 at 7 46 23 PM"
src="https://github.com/databricks/cli/assets/87999496/7030ebd8-b0e2-4ad1-a9e3-5ff8454f1175">
<img width="747" alt="Screenshot 2023-08-28 at 7 47 01 PM"
src="https://github.com/databricks/cli/assets/87999496/fb9b54d7-54e2-43ce-9148-68fb620c809a">

Signed-off-by: Arpit Jasapara <arpit.jasapara@databricks.com>
This commit is contained in:
Arpit Jasapara 2023-09-07 14:54:31 -07:00 committed by GitHub
parent 5a14c7cb43
commit 50eaf16307
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 342 additions and 6 deletions

View File

@ -77,6 +77,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error {
r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName}) r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName})
} }
for i := range r.ModelServingEndpoints {
prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_"
r.ModelServingEndpoints[i].Name = prefix + r.ModelServingEndpoints[i].Name
// (model serving doesn't yet support tags)
}
return nil return nil
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/ml"
"github.com/databricks/databricks-sdk-go/service/pipelines" "github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/databricks/databricks-sdk-go/service/serving"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -53,6 +54,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
Models: map[string]*resources.MlflowModel{ Models: map[string]*resources.MlflowModel{
"model1": {Model: &ml.Model{Name: "model1"}}, "model1": {Model: &ml.Model{Name: "model1"}},
}, },
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
"servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}},
},
}, },
}, },
} }
@ -69,6 +73,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) {
assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name) assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name)
assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name) assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name)
assert.Equal(t, "[dev lennart] model1", bundle.Config.Resources.Models["model1"].Name) assert.Equal(t, "[dev lennart] model1", bundle.Config.Resources.Models["model1"].Name)
assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key) assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key)
assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
} }
@ -82,6 +87,7 @@ func TestProcessTargetModeDefault(t *testing.T) {
assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
} }
func TestProcessTargetModeProduction(t *testing.T) { func TestProcessTargetModeProduction(t *testing.T) {
@ -109,6 +115,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions
bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions
bundle.Config.Resources.Models["model1"].Permissions = permissions bundle.Config.Resources.Models["model1"].Permissions = permissions
bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions
err = validateProductionMode(context.Background(), bundle, false) err = validateProductionMode(context.Background(), bundle, false)
require.NoError(t, err) require.NoError(t, err)
@ -116,6 +123,7 @@ func TestProcessTargetModeProduction(t *testing.T) {
assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name)
assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name)
assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name)
} }
func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) { func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) {

View File

@ -13,6 +13,7 @@ type Resources struct {
Models map[string]*resources.MlflowModel `json:"models,omitempty"` Models map[string]*resources.MlflowModel `json:"models,omitempty"`
Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"`
} }
type UniqueResourceIdTracker struct { type UniqueResourceIdTracker struct {
@ -93,6 +94,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker,
tracker.Type[k] = "mlflow_experiment" tracker.Type[k] = "mlflow_experiment"
tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath
} }
for k := range r.ModelServingEndpoints {
if _, ok := tracker.Type[k]; ok {
return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)",
k,
tracker.Type[k],
tracker.ConfigPath[k],
"model_serving_endpoint",
r.ModelServingEndpoints[k].ConfigFilePath,
)
}
tracker.Type[k] = "model_serving_endpoint"
tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath
}
return tracker, nil return tracker, nil
} }
@ -112,6 +126,9 @@ func (r *Resources) SetConfigFilePath(path string) {
for _, e := range r.Experiments { for _, e := range r.Experiments {
e.ConfigFilePath = path e.ConfigFilePath = path
} }
for _, e := range r.ModelServingEndpoints {
e.ConfigFilePath = path
}
} }
// MergeJobClusters iterates over all jobs and merges their job clusters. // MergeJobClusters iterates over all jobs and merges their job clusters.

View File

@ -0,0 +1,24 @@
package resources
import (
"github.com/databricks/cli/bundle/config/paths"
"github.com/databricks/databricks-sdk-go/service/serving"
)
type ModelServingEndpoint struct {
// This represents the input args for terraform, and will get converted
// to a HCL representation for CRUD
*serving.CreateServingEndpoint
// This represents the id (ie serving_endpoint_id) that can be used
// as a reference in other resources. This value is returned by terraform.
ID string
// Local path where the bundle is defined. All bundle resources include
// this for interpolation purposes.
paths.Paths
// This is a resource agnostic implementation of permissions for ACLs.
// Implementation could be different based on the resource type.
Permissions []Permission `json:"permissions,omitempty"`
}

View File

@ -161,6 +161,19 @@ func BundleToTerraform(config *config.Root) (*schema.Root, bool) {
} }
} }
for k, src := range config.Resources.ModelServingEndpoints {
noResources = false
var dst schema.ResourceModelServing
conv(src, &dst)
tfroot.Resource.ModelServing[k] = &dst
// Configure permissions for this resource.
if rp := convPermissions(src.Permissions); rp != nil {
rp.ServingEndpointId = fmt.Sprintf("${databricks_model_serving.%s.serving_endpoint_id}", k)
tfroot.Resource.Permissions["model_serving_"+k] = rp
}
}
return tfroot, noResources return tfroot, noResources
} }
@ -196,6 +209,12 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
cur := config.Resources.Experiments[resource.Name] cur := config.Resources.Experiments[resource.Name]
conv(tmp, &cur) conv(tmp, &cur)
config.Resources.Experiments[resource.Name] = cur config.Resources.Experiments[resource.Name] = cur
case "databricks_model_serving":
var tmp schema.ResourceModelServing
conv(resource.AttributeValues, &tmp)
cur := config.Resources.ModelServingEndpoints[resource.Name]
conv(tmp, &cur)
config.Resources.ModelServingEndpoints[resource.Name] = cur
case "databricks_permissions": case "databricks_permissions":
// Ignore; no need to pull these back into the configuration. // Ignore; no need to pull these back into the configuration.
default: default:

View File

@ -9,6 +9,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/ml"
"github.com/databricks/databricks-sdk-go/service/pipelines" "github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/databricks/databricks-sdk-go/service/serving"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -292,3 +293,76 @@ func TestConvertExperimentPermissions(t *testing.T) {
assert.Equal(t, "CAN_READ", p.PermissionLevel) assert.Equal(t, "CAN_READ", p.PermissionLevel)
} }
func TestConvertModelServing(t *testing.T) {
var src = resources.ModelServingEndpoint{
CreateServingEndpoint: &serving.CreateServingEndpoint{
Name: "name",
Config: serving.EndpointCoreConfigInput{
ServedModels: []serving.ServedModelInput{
{
ModelName: "model_name",
ModelVersion: "1",
ScaleToZeroEnabled: true,
WorkloadSize: "Small",
},
},
TrafficConfig: &serving.TrafficConfig{
Routes: []serving.Route{
{
ServedModelName: "model_name-1",
TrafficPercentage: 100,
},
},
},
},
},
}
var config = config.Root{
Resources: config.Resources{
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
"my_model_serving_endpoint": &src,
},
},
}
out, _ := BundleToTerraform(&config)
resource := out.Resource.ModelServing["my_model_serving_endpoint"]
assert.Equal(t, "name", resource.Name)
assert.Equal(t, "model_name", resource.Config.ServedModels[0].ModelName)
assert.Equal(t, "1", resource.Config.ServedModels[0].ModelVersion)
assert.Equal(t, true, resource.Config.ServedModels[0].ScaleToZeroEnabled)
assert.Equal(t, "Small", resource.Config.ServedModels[0].WorkloadSize)
assert.Equal(t, "model_name-1", resource.Config.TrafficConfig.Routes[0].ServedModelName)
assert.Equal(t, 100, resource.Config.TrafficConfig.Routes[0].TrafficPercentage)
assert.Nil(t, out.Data)
}
func TestConvertModelServingPermissions(t *testing.T) {
var src = resources.ModelServingEndpoint{
Permissions: []resources.Permission{
{
Level: "CAN_VIEW",
UserName: "jane@doe.com",
},
},
}
var config = config.Root{
Resources: config.Resources{
ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{
"my_model_serving_endpoint": &src,
},
},
}
out, _ := BundleToTerraform(&config)
assert.NotEmpty(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].ServingEndpointId)
assert.Len(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl, 1)
p := out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl[0]
assert.Equal(t, "jane@doe.com", p.UserName)
assert.Equal(t, "CAN_VIEW", p.PermissionLevel)
}

View File

@ -25,6 +25,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
case "experiments": case "experiments":
path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter) path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil return fmt.Sprintf("${%s}", path), nil
case "model_serving_endpoints":
path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
default: default:
panic("TODO: " + parts[1]) panic("TODO: " + parts[1])
} }

View File

@ -1441,6 +1441,87 @@
} }
} }
}, },
"model_serving_endpoints": {
"description": "List of Model Serving Endpoints",
"additionalproperties": {
"description": "",
"properties": {
"name": {
"description": "The name of the model serving endpoint. This field is required and must be unique across a workspace. An endpoint name can consist of alphanumeric characters, dashes, and underscores. NOTE: Changing this name will delete the existing endpoint and create a new endpoint with the update name."
},
"permissions": {
"description": "",
"items": {
"description": "",
"properties": {
"group_name": {
"description": ""
},
"level": {
"description": ""
},
"service_principal_name": {
"description": ""
},
"user_name": {
"description": ""
}
}
}
},
"config": {
"description": "The model serving endpoint configuration.",
"properties": {
"description": "",
"properties": {
"served_models": {
"description": "Each block represents a served model for the endpoint to serve. A model serving endpoint can have up to 10 served models.",
"items": {
"description": "",
"properties" : {
"name": {
"description": "The name of a served model. It must be unique across an endpoint. If not specified, this field will default to modelname-modelversion. A served model name can consist of alphanumeric characters, dashes, and underscores."
},
"model_name": {
"description": "The name of the model in Databricks Model Registry to be served."
},
"model_version": {
"description": "The version of the model in Databricks Model Registry to be served."
},
"workload_size": {
"description": "The workload size of the served model. The workload size corresponds to a range of provisioned concurrency that the compute will autoscale between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are \"Small\" (4 - 4 provisioned concurrency), \"Medium\" (8 - 16 provisioned concurrency), and \"Large\" (16 - 64 provisioned concurrency)."
},
"scale_to_zero_enabled": {
"description": "Whether the compute resources for the served model should scale down to zero. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size will be 0."
}
}
}
},
"traffic_config": {
"description": "A single block represents the traffic split configuration amongst the served models.",
"properties": {
"routes": {
"description": "Each block represents a route that defines traffic to each served model. Each served_models block needs to have a corresponding routes block.",
"items": {
"description": "",
"properties": {
"served_model_name": {
"description": "The name of the served model this route configures traffic for. This needs to match the name of a served_models block."
},
"traffic_percentage": {
"description": "The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive."
}
}
}
}
}
}
}
}
}
}
}
},
"pipelines": { "pipelines": {
"description": "List of DLT pipelines", "description": "List of DLT pipelines",
"additionalproperties": { "additionalproperties": {

View File

@ -210,6 +210,19 @@ func (reader *OpenapiReader) modelsDocs() (*Docs, error) {
return modelsDocs, nil return modelsDocs, nil
} }
func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) {
modelServingEndpointsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "serving.CreateServingEndpoint")
if err != nil {
return nil, err
}
modelServingEndpointsDocs := schemaToDocs(modelServingEndpointsSpecSchema)
modelServingEndpointsAllDocs := &Docs{
Description: "List of Model Serving Endpoints",
AdditionalProperties: modelServingEndpointsDocs,
}
return modelServingEndpointsAllDocs, nil
}
func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
jobsDocs, err := reader.jobsDocs() jobsDocs, err := reader.jobsDocs()
if err != nil { if err != nil {
@ -227,6 +240,10 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
modelServingEndpointsDocs, err := reader.modelServingEndpointsDocs()
if err != nil {
return nil, err
}
return &Docs{ return &Docs{
Description: "Collection of Databricks resources to deploy.", Description: "Collection of Databricks resources to deploy.",
@ -235,6 +252,7 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) {
"pipelines": pipelinesDocs, "pipelines": pipelinesDocs,
"experiments": experimentsDocs, "experiments": experimentsDocs,
"models": modelsDocs, "models": modelsDocs,
"model_serving_endpoints": modelServingEndpointsDocs,
}, },
}, nil }, nil
} }

View File

@ -0,0 +1,38 @@
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-endpoint"
config:
served_models:
- model_name: "model-name"
model_version: "1"
workload_size: "Small"
scale_to_zero_enabled: true
traffic_config:
routes:
- served_model_name: "model-name-1"
traffic_percentage: 100
permissions:
- level: CAN_QUERY
group_name: users
targets:
development:
mode: development
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-dev-endpoint"
staging:
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-staging-endpoint"
production:
mode: production
resources:
model_serving_endpoints:
my_model_serving_endpoint:
name: "my-prod-endpoint"

View File

@ -0,0 +1,48 @@
package config_tests
import (
"path/filepath"
"testing"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/stretchr/testify/assert"
)
func assertExpected(t *testing.T, p *resources.ModelServingEndpoint) {
assert.Equal(t, "model_serving_endpoint/databricks.yml", filepath.ToSlash(p.ConfigFilePath))
assert.Equal(t, "model-name", p.Config.ServedModels[0].ModelName)
assert.Equal(t, "1", p.Config.ServedModels[0].ModelVersion)
assert.Equal(t, "model-name-1", p.Config.TrafficConfig.Routes[0].ServedModelName)
assert.Equal(t, 100, p.Config.TrafficConfig.Routes[0].TrafficPercentage)
assert.Equal(t, "users", p.Permissions[0].GroupName)
assert.Equal(t, "CAN_QUERY", p.Permissions[0].Level)
}
func TestModelServingEndpointDevelopment(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "development")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
assert.Equal(t, b.Config.Bundle.Mode, config.Development)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-dev-endpoint", p.Name)
assertExpected(t, p)
}
func TestModelServingEndpointStaging(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "staging")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-staging-endpoint", p.Name)
assertExpected(t, p)
}
func TestModelServingEndpointProduction(t *testing.T) {
b := loadTarget(t, "./model_serving_endpoint", "production")
assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1)
p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"]
assert.Equal(t, "my-prod-endpoint", p.Name)
assertExpected(t, p)
}