From 58563b1ea94d7662ab305eda2af8e4e702c4f4aa Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Mon, 20 Mar 2023 21:28:43 +0100 Subject: [PATCH] Add resources for mlflow models and experiments (#263) Manually confirmed that both can be deployed. --- bundle/config/resources.go | 3 + bundle/config/resources/mlflow_experiment.go | 7 +++ bundle/config/resources/mlflow_model.go | 7 +++ bundle/deploy/terraform/convert.go | 24 ++++++++ bundle/deploy/terraform/convert_test.go | 58 ++++++++++++++++++++ bundle/deploy/terraform/interpolate.go | 6 ++ 6 files changed, 105 insertions(+) create mode 100644 bundle/config/resources/mlflow_experiment.go create mode 100644 bundle/config/resources/mlflow_model.go diff --git a/bundle/config/resources.go b/bundle/config/resources.go index 0c5396beb..7fa48357c 100644 --- a/bundle/config/resources.go +++ b/bundle/config/resources.go @@ -8,4 +8,7 @@ import ( type Resources struct { Jobs map[string]*resources.Job `json:"jobs,omitempty"` Pipelines map[string]*resources.Pipeline `json:"pipelines,omitempty"` + + Models map[string]*resources.MlflowModel `json:"models,omitempty"` + Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"` } diff --git a/bundle/config/resources/mlflow_experiment.go b/bundle/config/resources/mlflow_experiment.go new file mode 100644 index 000000000..c335821b3 --- /dev/null +++ b/bundle/config/resources/mlflow_experiment.go @@ -0,0 +1,7 @@ +package resources + +import "github.com/databricks/databricks-sdk-go/service/mlflow" + +type MlflowExperiment struct { + *mlflow.Experiment +} diff --git a/bundle/config/resources/mlflow_model.go b/bundle/config/resources/mlflow_model.go new file mode 100644 index 000000000..29354bf71 --- /dev/null +++ b/bundle/config/resources/mlflow_model.go @@ -0,0 +1,7 @@ +package resources + +import "github.com/databricks/databricks-sdk-go/service/mlflow" + +type MlflowModel struct { + *mlflow.RegisteredModel +} diff --git a/bundle/deploy/terraform/convert.go b/bundle/deploy/terraform/convert.go index 6eb4178ac..042d175c9 100644 --- a/bundle/deploy/terraform/convert.go +++ b/bundle/deploy/terraform/convert.go @@ -81,6 +81,18 @@ func BundleToTerraform(config *config.Root) *schema.Root { tfroot.Resource.Pipeline[k] = &dst } + for k, src := range config.Resources.Models { + var dst schema.ResourceMlflowModel + conv(src, &dst) + tfroot.Resource.MlflowModel[k] = &dst + } + + for k, src := range config.Resources.Experiments { + var dst schema.ResourceMlflowExperiment + conv(src, &dst) + tfroot.Resource.MlflowExperiment[k] = &dst + } + return tfroot } @@ -112,6 +124,18 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error { cur := config.Resources.Pipelines[resource.Name] conv(tmp, &cur) config.Resources.Pipelines[resource.Name] = cur + case "databricks_mlflow_model": + var tmp schema.ResourceMlflowModel + conv(resource.AttributeValues, &tmp) + cur := config.Resources.Models[resource.Name] + conv(tmp, &cur) + config.Resources.Models[resource.Name] = cur + case "databricks_mlflow_experiment": + var tmp schema.ResourceMlflowExperiment + conv(resource.AttributeValues, &tmp) + cur := config.Resources.Experiments[resource.Name] + conv(tmp, &cur) + config.Resources.Experiments[resource.Name] = cur default: return fmt.Errorf("missing mapping for %s", resource.Type) } diff --git a/bundle/deploy/terraform/convert_test.go b/bundle/deploy/terraform/convert_test.go index 3cc805489..60c91a2ca 100644 --- a/bundle/deploy/terraform/convert_test.go +++ b/bundle/deploy/terraform/convert_test.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/clusters" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/libraries" + "github.com/databricks/databricks-sdk-go/service/mlflow" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -79,3 +80,60 @@ func TestConvertJobTaskLibraries(t *testing.T) { require.Len(t, out.Resource.Job["my_job"].Task[0].Library, 1) assert.Equal(t, "mlflow", out.Resource.Job["my_job"].Task[0].Library[0].Pypi.Package) } + +func TestConvertModel(t *testing.T) { + var src = resources.MlflowModel{ + RegisteredModel: &mlflow.RegisteredModel{ + Name: "name", + Description: "description", + Tags: []mlflow.RegisteredModelTag{ + { + Key: "k1", + Value: "v1", + }, + { + Key: "k2", + Value: "v2", + }, + }, + }, + } + + var config = config.Root{ + Resources: config.Resources{ + Models: map[string]*resources.MlflowModel{ + "my_model": &src, + }, + }, + } + + out := BundleToTerraform(&config) + assert.Equal(t, "name", out.Resource.MlflowModel["my_model"].Name) + assert.Equal(t, "description", out.Resource.MlflowModel["my_model"].Description) + assert.Len(t, out.Resource.MlflowModel["my_model"].Tags, 2) + assert.Equal(t, "k1", out.Resource.MlflowModel["my_model"].Tags[0].Key) + assert.Equal(t, "v1", out.Resource.MlflowModel["my_model"].Tags[0].Value) + assert.Equal(t, "k2", out.Resource.MlflowModel["my_model"].Tags[1].Key) + assert.Equal(t, "v2", out.Resource.MlflowModel["my_model"].Tags[1].Value) + assert.Nil(t, out.Data) +} + +func TestConvertExperiment(t *testing.T) { + var src = resources.MlflowExperiment{ + Experiment: &mlflow.Experiment{ + Name: "name", + }, + } + + var config = config.Root{ + Resources: config.Resources{ + Experiments: map[string]*resources.MlflowExperiment{ + "my_experiment": &src, + }, + }, + } + + out := BundleToTerraform(&config) + assert.Equal(t, "name", out.Resource.MlflowExperiment["my_experiment"].Name) + assert.Nil(t, out.Data) +} diff --git a/bundle/deploy/terraform/interpolate.go b/bundle/deploy/terraform/interpolate.go index c7c8948c7..bf9edc568 100644 --- a/bundle/deploy/terraform/interpolate.go +++ b/bundle/deploy/terraform/interpolate.go @@ -19,6 +19,12 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri case "jobs": path = strings.Join(append([]string{"databricks_job"}, parts[2:]...), interpolation.Delimiter) return fmt.Sprintf("${%s}", path), nil + case "models": + path = strings.Join(append([]string{"databricks_mlflow_model"}, parts[2:]...), interpolation.Delimiter) + return fmt.Sprintf("${%s}", path), nil + case "experiments": + path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter) + return fmt.Sprintf("${%s}", path), nil default: panic("TODO: " + parts[1]) }