Add resources for mlflow models and experiments (#263)

Manually confirmed that both can be deployed.
This commit is contained in:
Pieter Noordhuis 2023-03-20 21:28:43 +01:00 committed by GitHub
parent 077ab8b864
commit 58563b1ea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 105 additions and 0 deletions

View File

@ -8,4 +8,7 @@ import (
type Resources struct {
Jobs map[string]*resources.Job `json:"jobs,omitempty"`
Pipelines map[string]*resources.Pipeline `json:"pipelines,omitempty"`
Models map[string]*resources.MlflowModel `json:"models,omitempty"`
Experiments map[string]*resources.MlflowExperiment `json:"experiments,omitempty"`
}

View File

@ -0,0 +1,7 @@
package resources
import "github.com/databricks/databricks-sdk-go/service/mlflow"
type MlflowExperiment struct {
*mlflow.Experiment
}

View File

@ -0,0 +1,7 @@
package resources
import "github.com/databricks/databricks-sdk-go/service/mlflow"
type MlflowModel struct {
*mlflow.RegisteredModel
}

View File

@ -81,6 +81,18 @@ func BundleToTerraform(config *config.Root) *schema.Root {
tfroot.Resource.Pipeline[k] = &dst
}
for k, src := range config.Resources.Models {
var dst schema.ResourceMlflowModel
conv(src, &dst)
tfroot.Resource.MlflowModel[k] = &dst
}
for k, src := range config.Resources.Experiments {
var dst schema.ResourceMlflowExperiment
conv(src, &dst)
tfroot.Resource.MlflowExperiment[k] = &dst
}
return tfroot
}
@ -112,6 +124,18 @@ func TerraformToBundle(state *tfjson.State, config *config.Root) error {
cur := config.Resources.Pipelines[resource.Name]
conv(tmp, &cur)
config.Resources.Pipelines[resource.Name] = cur
case "databricks_mlflow_model":
var tmp schema.ResourceMlflowModel
conv(resource.AttributeValues, &tmp)
cur := config.Resources.Models[resource.Name]
conv(tmp, &cur)
config.Resources.Models[resource.Name] = cur
case "databricks_mlflow_experiment":
var tmp schema.ResourceMlflowExperiment
conv(resource.AttributeValues, &tmp)
cur := config.Resources.Experiments[resource.Name]
conv(tmp, &cur)
config.Resources.Experiments[resource.Name] = cur
default:
return fmt.Errorf("missing mapping for %s", resource.Type)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/databricks/databricks-sdk-go/service/clusters"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/libraries"
"github.com/databricks/databricks-sdk-go/service/mlflow"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -79,3 +80,60 @@ func TestConvertJobTaskLibraries(t *testing.T) {
require.Len(t, out.Resource.Job["my_job"].Task[0].Library, 1)
assert.Equal(t, "mlflow", out.Resource.Job["my_job"].Task[0].Library[0].Pypi.Package)
}
func TestConvertModel(t *testing.T) {
var src = resources.MlflowModel{
RegisteredModel: &mlflow.RegisteredModel{
Name: "name",
Description: "description",
Tags: []mlflow.RegisteredModelTag{
{
Key: "k1",
Value: "v1",
},
{
Key: "k2",
Value: "v2",
},
},
},
}
var config = config.Root{
Resources: config.Resources{
Models: map[string]*resources.MlflowModel{
"my_model": &src,
},
},
}
out := BundleToTerraform(&config)
assert.Equal(t, "name", out.Resource.MlflowModel["my_model"].Name)
assert.Equal(t, "description", out.Resource.MlflowModel["my_model"].Description)
assert.Len(t, out.Resource.MlflowModel["my_model"].Tags, 2)
assert.Equal(t, "k1", out.Resource.MlflowModel["my_model"].Tags[0].Key)
assert.Equal(t, "v1", out.Resource.MlflowModel["my_model"].Tags[0].Value)
assert.Equal(t, "k2", out.Resource.MlflowModel["my_model"].Tags[1].Key)
assert.Equal(t, "v2", out.Resource.MlflowModel["my_model"].Tags[1].Value)
assert.Nil(t, out.Data)
}
func TestConvertExperiment(t *testing.T) {
var src = resources.MlflowExperiment{
Experiment: &mlflow.Experiment{
Name: "name",
},
}
var config = config.Root{
Resources: config.Resources{
Experiments: map[string]*resources.MlflowExperiment{
"my_experiment": &src,
},
},
}
out := BundleToTerraform(&config)
assert.Equal(t, "name", out.Resource.MlflowExperiment["my_experiment"].Name)
assert.Nil(t, out.Data)
}

View File

@ -19,6 +19,12 @@ func interpolateTerraformResourceIdentifiers(path string, lookup map[string]stri
case "jobs":
path = strings.Join(append([]string{"databricks_job"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
case "models":
path = strings.Join(append([]string{"databricks_mlflow_model"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
case "experiments":
path = strings.Join(append([]string{"databricks_mlflow_experiment"}, parts[2:]...), interpolation.Delimiter)
return fmt.Sprintf("${%s}", path), nil
default:
panic("TODO: " + parts[1])
}