diff --git a/bundle/config/mutator/process_target_mode.go b/bundle/config/mutator/process_target_mode.go index 06ae7b85..93149ad0 100644 --- a/bundle/config/mutator/process_target_mode.go +++ b/bundle/config/mutator/process_target_mode.go @@ -77,6 +77,12 @@ func transformDevelopmentMode(b *bundle.Bundle) error { r.Experiments[i].Tags = append(r.Experiments[i].Tags, ml.ExperimentTag{Key: "dev", Value: b.Config.Workspace.CurrentUser.DisplayName}) } + for i := range r.ModelServingEndpoints { + prefix = "dev_" + b.Config.Workspace.CurrentUser.ShortName + "_" + r.ModelServingEndpoints[i].Name = prefix + r.ModelServingEndpoints[i].Name + // (model serving doesn't yet support tags) + } + return nil } diff --git a/bundle/config/mutator/process_target_mode_test.go b/bundle/config/mutator/process_target_mode_test.go index 489632e1..4ea33c70 100644 --- a/bundle/config/mutator/process_target_mode_test.go +++ b/bundle/config/mutator/process_target_mode_test.go @@ -13,6 +13,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/databricks/databricks-sdk-go/service/serving" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -53,6 +54,9 @@ func mockBundle(mode config.Mode) *bundle.Bundle { Models: map[string]*resources.MlflowModel{ "model1": {Model: &ml.Model{Name: "model1"}}, }, + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "servingendpoint1": {CreateServingEndpoint: &serving.CreateServingEndpoint{Name: "servingendpoint1"}}, + }, }, }, } @@ -69,6 +73,7 @@ func TestProcessTargetModeDevelopment(t *testing.T) { assert.Equal(t, "/Users/lennart.kats@databricks.com/[dev lennart] experiment1", bundle.Config.Resources.Experiments["experiment1"].Name) assert.Equal(t, "[dev lennart] experiment2", bundle.Config.Resources.Experiments["experiment2"].Name) assert.Equal(t, "[dev lennart] model1", 
bundle.Config.Resources.Models["model1"].Name) + assert.Equal(t, "dev_lennart_servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) assert.Equal(t, "dev", bundle.Config.Resources.Experiments["experiment1"].Experiment.Tags[0].Key) assert.True(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) } @@ -82,6 +87,7 @@ func TestProcessTargetModeDefault(t *testing.T) { assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) + assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) } func TestProcessTargetModeProduction(t *testing.T) { @@ -109,6 +115,7 @@ func TestProcessTargetModeProduction(t *testing.T) { bundle.Config.Resources.Experiments["experiment1"].Permissions = permissions bundle.Config.Resources.Experiments["experiment2"].Permissions = permissions bundle.Config.Resources.Models["model1"].Permissions = permissions + bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Permissions = permissions err = validateProductionMode(context.Background(), bundle, false) require.NoError(t, err) @@ -116,6 +123,7 @@ func TestProcessTargetModeProduction(t *testing.T) { assert.Equal(t, "job1", bundle.Config.Resources.Jobs["job1"].Name) assert.Equal(t, "pipeline1", bundle.Config.Resources.Pipelines["pipeline1"].Name) assert.False(t, bundle.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) + assert.Equal(t, "servingendpoint1", bundle.Config.Resources.ModelServingEndpoints["servingendpoint1"].Name) } func TestProcessTargetModeProductionOkForPrincipal(t *testing.T) { diff --git a/bundle/config/resources.go b/bundle/config/resources.go index 5d47b918..c239b510 100644 --- a/bundle/config/resources.go +++ b/bundle/config/resources.go @@ -11,8 +11,9 @@ type Resources 
struct { Jobs map[string]*resources.Job `json:"jobs,omitempty"` Pipelines map[string]*resources.Pipeline `json:"pipelines,omitempty"` - Models map[string]*resources.MlflowModel `json:"models,omitempty"` - Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` + Models map[string]*resources.MlflowModel `json:"models,omitempty"` + Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` + ModelServingEndpoints map[string]*resources.ModelServingEndpoint `json:"model_serving_endpoints,omitempty"` } type UniqueResourceIdTracker struct { @@ -93,6 +94,19 @@ func (r *Resources) VerifyUniqueResourceIdentifiers() (*UniqueResourceIdTracker, tracker.Type[k] = "mlflow_experiment" tracker.ConfigPath[k] = r.Experiments[k].ConfigFilePath } + for k := range r.ModelServingEndpoints { + if _, ok := tracker.Type[k]; ok { + return tracker, fmt.Errorf("multiple resources named %s (%s at %s, %s at %s)", + k, + tracker.Type[k], + tracker.ConfigPath[k], + "model_serving_endpoint", + r.ModelServingEndpoints[k].ConfigFilePath, + ) + } + tracker.Type[k] = "model_serving_endpoint" + tracker.ConfigPath[k] = r.ModelServingEndpoints[k].ConfigFilePath + } return tracker, nil } @@ -112,6 +126,9 @@ func (r *Resources) SetConfigFilePath(path string) { for _, e := range r.Experiments { e.ConfigFilePath = path } + for _, e := range r.ModelServingEndpoints { + e.ConfigFilePath = path + } } // MergeJobClusters iterates over all jobs and merges their job clusters. 
diff --git a/bundle/config/resources/model_serving_endpoint.go b/bundle/config/resources/model_serving_endpoint.go new file mode 100644 index 00000000..dccecaa6 --- /dev/null +++ b/bundle/config/resources/model_serving_endpoint.go @@ -0,0 +1,24 @@ +package resources + +import ( + "github.com/databricks/cli/bundle/config/paths" + "github.com/databricks/databricks-sdk-go/service/serving" +) + +type ModelServingEndpoint struct { + // This represents the input args for terraform, and will get converted + // to a HCL representation for CRUD + *serving.CreateServingEndpoint + + // This represents the id (ie serving_endpoint_id) that can be used + // as a reference in other resources. This value is returned by terraform. + ID string + + // Local path where the bundle is defined. All bundle resources include + // this for interpolation purposes. + paths.Paths + + // This is a resource agnostic implementation of permissions for ACLs. + // Implementation could be different based on the resource type. + Permissions []Permission `json:"permissions,omitempty"` +} diff --git a/bundle/deploy/terraform/convert.go b/bundle/deploy/terraform/convert.go index cd480c89..0956ea7b 100644 --- a/bundle/deploy/terraform/convert.go +++ b/bundle/deploy/terraform/convert.go @@ -161,6 +161,19 @@ func BundleToTerraform(config *config.Root) (*schema.Root, bool) { } } + for k, src := range config.Resources.ModelServingEndpoints { + noResources = false + var dst schema.ResourceModelServing + conv(src, &dst) + tfroot.Resource.ModelServing[k] = &dst + + // Configure permissions for this resource. 
+ if rp := convPermissions(src.Permissions); rp != nil { + rp.ServingEndpointId = fmt.Sprintf("${databricks_model_serving.%s.serving_endpoint_id}", k) + tfroot.Resource.Permissions["model_serving_"+k] = rp + } + } + return tfroot, noResources } @@ -196,6 +209,12 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error { cur := config.Resources.Experiments[resource.Name] conv(tmp, &cur) config.Resources.Experiments[resource.Name] = cur + case "databricks_model_serving": + var tmp schema.ResourceModelServing + conv(resource.AttributeValues, &tmp) + cur := config.Resources.ModelServingEndpoints[resource.Name] + conv(tmp, &cur) + config.Resources.ModelServingEndpoints[resource.Name] = cur case "databricks_permissions": // Ignore; no need to pull these back into the configuration. default: diff --git a/bundle/deploy/terraform/convert_test.go b/bundle/deploy/terraform/convert_test.go index 34a65d70..ad626606 100644 --- a/bundle/deploy/terraform/convert_test.go +++ b/bundle/deploy/terraform/convert_test.go @@ -9,6 +9,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/ml" "github.com/databricks/databricks-sdk-go/service/pipelines" + "github.com/databricks/databricks-sdk-go/service/serving" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -292,3 +293,76 @@ func TestConvertExperimentPermissions(t *testing.T) { assert.Equal(t, "CAN_READ", p.PermissionLevel) } + +func TestConvertModelServing(t *testing.T) { + var src = resources.ModelServingEndpoint{ + CreateServingEndpoint: &serving.CreateServingEndpoint{ + Name: "name", + Config: serving.EndpointCoreConfigInput{ + ServedModels: []serving.ServedModelInput{ + { + ModelName: "model_name", + ModelVersion: "1", + ScaleToZeroEnabled: true, + WorkloadSize: "Small", + }, + }, + TrafficConfig: &serving.TrafficConfig{ + Routes: []serving.Route{ + { + ServedModelName: "model_name-1", + TrafficPercentage: 100, + }, + }, + 
}, + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "my_model_serving_endpoint": &src, + }, + }, + } + + out, _ := BundleToTerraform(&config) + resource := out.Resource.ModelServing["my_model_serving_endpoint"] + assert.Equal(t, "name", resource.Name) + assert.Equal(t, "model_name", resource.Config.ServedModels[0].ModelName) + assert.Equal(t, "1", resource.Config.ServedModels[0].ModelVersion) + assert.Equal(t, true, resource.Config.ServedModels[0].ScaleToZeroEnabled) + assert.Equal(t, "Small", resource.Config.ServedModels[0].WorkloadSize) + assert.Equal(t, "model_name-1", resource.Config.TrafficConfig.Routes[0].ServedModelName) + assert.Equal(t, 100, resource.Config.TrafficConfig.Routes[0].TrafficPercentage) + assert.Nil(t, out.Data) +} + +func TestConvertModelServingPermissions(t *testing.T) { + var src = resources.ModelServingEndpoint{ + Permissions: []resources.Permission{ + { + Level: "CAN_VIEW", + UserName: "jane@doe.com", + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + ModelServingEndpoints: map[string]*resources.ModelServingEndpoint{ + "my_model_serving_endpoint": &src, + }, + }, + } + + out, _ := BundleToTerraform(&config) + assert.NotEmpty(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].ServingEndpointId) + assert.Len(t, out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl, 1) + + p := out.Resource.Permissions["model_serving_my_model_serving_endpoint"].AccessControl[0] + assert.Equal(t, "jane@doe.com", p.UserName) + assert.Equal(t, "CAN_VIEW", p.PermissionLevel) + +} diff --git a/bundle/deploy/terraform/interpolate.go b/bundle/deploy/terraform/interpolate.go index dd1dcbb8..ea3c99aa 100644 --- a/bundle/deploy/terraform/interpolate.go +++ b/bundle/deploy/terraform/interpolate.go @@ -25,6 +25,9 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri case 
"experiments": path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter) return fmt.Sprintf("${%s}", path), nil + case "model_serving_endpoints": + path = strings.Join(append([]string{"databricks_model_serving"}, parts[2:]...), interpolation.Delimiter) + return fmt.Sprintf("${%s}", path), nil default: panic("TODO: " + parts[1]) } diff --git a/bundle/schema/docs/bundle_descriptions.json b/bundle/schema/docs/bundle_descriptions.json index 84f0492f..ffdb5629 100644 --- a/bundle/schema/docs/bundle_descriptions.json +++ b/bundle/schema/docs/bundle_descriptions.json @@ -1441,6 +1441,87 @@ } } }, + "model_serving_endpoints": { + "description": "List of Model Serving Endpoints", + "additionalproperties": { + "description": "", + "properties": { + "name": { + "description": "The name of the model serving endpoint. This field is required and must be unique across a workspace. An endpoint name can consist of alphanumeric characters, dashes, and underscores. NOTE: Changing this name will delete the existing endpoint and create a new endpoint with the updated name." + }, + "permissions": { + "description": "", + "items": { + "description": "", + "properties": { + "group_name": { + "description": "" + }, + "level": { + "description": "" + }, + "service_principal_name": { + "description": "" + }, + "user_name": { + "description": "" + } + } + } + }, + "config": { + "description": "The model serving endpoint configuration.", + "properties": { + "description": "", + "properties": { + "served_models": { + "description": "Each block represents a served model for the endpoint to serve. A model serving endpoint can have up to 10 served models.", + "items": { + "description": "", + "properties": { + "name": { + "description": "The name of a served model. It must be unique across an endpoint. If not specified, this field will default to modelname-modelversion.
A served model name can consist of alphanumeric characters, dashes, and underscores." + }, + "model_name": { + "description": "The name of the model in Databricks Model Registry to be served." + }, + "model_version": { + "description": "The version of the model in Databricks Model Registry to be served." + }, + "workload_size": { + "description": "The workload size of the served model. The workload size corresponds to a range of provisioned concurrency that the compute will autoscale between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are \"Small\" (4 - 4 provisioned concurrency), \"Medium\" (8 - 16 provisioned concurrency), and \"Large\" (16 - 64 provisioned concurrency)." + }, + "scale_to_zero_enabled": { + "description": "Whether the compute resources for the served model should scale down to zero. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size will be 0." + } + } + } + }, + "traffic_config": { + "description": "A single block represents the traffic split configuration amongst the served models.", + "properties": { + "routes": { + "description": "Each block represents a route that defines traffic to each served model. Each served_models block needs to have a corresponding routes block.", + "items": { + "description": "", + "properties": { + "served_model_name": { + "description": "The name of the served model this route configures traffic for. This needs to match the name of a served_models block." + }, + "traffic_percentage": { + "description": "The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive." 
+ } + } + } + } + } + } + } + } + } + } + } + }, "pipelines": { "description": "List of DLT pipelines", "additionalproperties": { diff --git a/bundle/schema/openapi.go b/bundle/schema/openapi.go index b0d67657..1a8b76ed 100644 --- a/bundle/schema/openapi.go +++ b/bundle/schema/openapi.go @@ -210,6 +210,19 @@ func (reader *OpenapiReader) modelsDocs() (*Docs, error) { return modelsDocs, nil } +func (reader *OpenapiReader) modelServingEndpointsDocs() (*Docs, error) { + modelServingEndpointsSpecSchema, err := reader.readResolvedSchema(SchemaPathPrefix + "serving.CreateServingEndpoint") + if err != nil { + return nil, err + } + modelServingEndpointsDocs := schemaToDocs(modelServingEndpointsSpecSchema) + modelServingEndpointsAllDocs := &Docs{ + Description: "List of Model Serving Endpoints", + AdditionalProperties: modelServingEndpointsDocs, + } + return modelServingEndpointsAllDocs, nil +} + func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { jobsDocs, err := reader.jobsDocs() if err != nil { @@ -227,14 +240,19 @@ func (reader *OpenapiReader) ResourcesDocs() (*Docs, error) { if err != nil { return nil, err } + modelServingEndpointsDocs, err := reader.modelServingEndpointsDocs() + if err != nil { + return nil, err + } return &Docs{ Description: "Collection of Databricks resources to deploy.", Properties: map[string]*Docs{ - "jobs": jobsDocs, - "pipelines": pipelinesDocs, - "experiments": experimentsDocs, - "models": modelsDocs, + "jobs": jobsDocs, + "pipelines": pipelinesDocs, + "experiments": experimentsDocs, + "models": modelsDocs, + "model_serving_endpoints": modelServingEndpointsDocs, }, }, nil } diff --git a/bundle/tests/model_serving_endpoint/databricks.yml b/bundle/tests/model_serving_endpoint/databricks.yml new file mode 100644 index 00000000..e4fb54a1 --- /dev/null +++ b/bundle/tests/model_serving_endpoint/databricks.yml @@ -0,0 +1,38 @@ +resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-endpoint" + config: + served_models: 
+ - model_name: "model-name" + model_version: "1" + workload_size: "Small" + scale_to_zero_enabled: true + traffic_config: + routes: + - served_model_name: "model-name-1" + traffic_percentage: 100 + permissions: + - level: CAN_QUERY + group_name: users + +targets: + development: + mode: development + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-dev-endpoint" + + staging: + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-staging-endpoint" + + production: + mode: production + resources: + model_serving_endpoints: + my_model_serving_endpoint: + name: "my-prod-endpoint" diff --git a/bundle/tests/model_serving_endpoint_test.go b/bundle/tests/model_serving_endpoint_test.go new file mode 100644 index 00000000..bfa1a31b --- /dev/null +++ b/bundle/tests/model_serving_endpoint_test.go @@ -0,0 +1,48 @@ +package config_tests + +import ( + "path/filepath" + "testing" + + "github.com/databricks/cli/bundle/config" + "github.com/databricks/cli/bundle/config/resources" + "github.com/stretchr/testify/assert" +) + +func assertExpected(t *testing.T, p *resources.ModelServingEndpoint) { + assert.Equal(t, "model_serving_endpoint/databricks.yml", filepath.ToSlash(p.ConfigFilePath)) + assert.Equal(t, "model-name", p.Config.ServedModels[0].ModelName) + assert.Equal(t, "1", p.Config.ServedModels[0].ModelVersion) + assert.Equal(t, "model-name-1", p.Config.TrafficConfig.Routes[0].ServedModelName) + assert.Equal(t, 100, p.Config.TrafficConfig.Routes[0].TrafficPercentage) + assert.Equal(t, "users", p.Permissions[0].GroupName) + assert.Equal(t, "CAN_QUERY", p.Permissions[0].Level) +} + +func TestModelServingEndpointDevelopment(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "development") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + assert.Equal(t, b.Config.Bundle.Mode, config.Development) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, 
"my-dev-endpoint", p.Name) + assertExpected(t, p) +} + +func TestModelServingEndpointStaging(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "staging") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, "my-staging-endpoint", p.Name) + assertExpected(t, p) +} + +func TestModelServingEndpointProduction(t *testing.T) { + b := loadTarget(t, "./model_serving_endpoint", "production") + assert.Len(t, b.Config.Resources.ModelServingEndpoints, 1) + + p := b.Config.Resources.ModelServingEndpoints["my_model_serving_endpoint"] + assert.Equal(t, "my-prod-endpoint", p.Name) + assertExpected(t, p) +}