Store project object in context.Context instead of global (#61)

* Load project root from `BRICKS_ROOT` environment variable
* Rename project.Project -> project.Config
* Rename project.inner -> project.project
* Upgrade cobra to 1.5.0 for cmd.SetContext
This commit is contained in:
Pieter Noordhuis 2022-09-16 11:06:58 +02:00 committed by GitHub
parent 836ab58473
commit a7701cc8f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 274 additions and 132 deletions

View File

@ -55,7 +55,7 @@ func (cfg *Configs) loadInteractive(cmd *cobra.Command) error {
Default: func(res prompt.Results) string {
return cfg.Host
},
Callback: func(ans prompt.Answer, prj *project.Project, res prompt.Results) {
Callback: func(ans prompt.Answer, config *project.Config, res prompt.Results) {
cfg.Host = ans.Value
},
})
@ -70,7 +70,7 @@ func (cfg *Configs) loadInteractive(cmd *cobra.Command) error {
Default: func(res prompt.Results) string {
return cfg.Token
},
Callback: func(ans prompt.Answer, prj *project.Project, res prompt.Results) {
Callback: func(ans prompt.Answer, config *project.Config, res prompt.Results) {
cfg.Token = ans.Value
},
})

View File

@ -13,8 +13,10 @@ var lsCmd = &cobra.Command{
Short: "Lists files",
Long: `Lists files`,
Args: cobra.ExactArgs(1),
PreRunE: project.Configure,
Run: func(cmd *cobra.Command, args []string) {
wsc := project.Current.WorkspacesClient()
wsc := project.Get(cmd.Context()).WorkspacesClient()
listStatusResponse, err := wsc.Dbfs.ListByPath(cmd.Context(), args[0])
if err != nil {
panic(err)

View File

@ -38,8 +38,8 @@ var initCmd = &cobra.Command{
Default: func(res prompt.Results) string {
return path.Base(wd)
},
Callback: func(ans prompt.Answer, prj *project.Project, res prompt.Results) {
prj.Name = ans.Value
Callback: func(ans prompt.Answer, config *project.Config, res prompt.Results) {
config.Name = ans.Value
},
},
*profileChoice,
@ -65,8 +65,8 @@ var initCmd = &cobra.Command{
Value: "Soft",
Details: "Prepend prefixes to each team member's deployment",
Callback: func(
ans prompt.Answer, prj *project.Project, res prompt.Results) {
prj.Isolation = project.Soft
ans prompt.Answer, config *project.Config, res prompt.Results) {
config.Isolation = project.Soft
},
},
}},
@ -92,14 +92,14 @@ var initCmd = &cobra.Command{
if err != nil {
return err
}
var prj project.Project
var config project.Config
for _, ans := range res {
if ans.Callback == nil {
continue
}
ans.Callback(ans, &prj, res)
ans.Callback(ans, &config, res)
}
raw, err := yaml.Marshal(prj)
raw, err := yaml.Marshal(config)
if err != nil {
return err
}

View File

@ -28,8 +28,8 @@ func loadCliProfiles() (profiles []prompt.Answer, err error) {
profiles = append(profiles, prompt.Answer{
Value: v.Name(),
Details: fmt.Sprintf(`Connecting to "%s" workspace`, host),
Callback: func(ans prompt.Answer, prj *project.Project, _ prompt.Results) {
prj.Profile = ans.Value
Callback: func(ans prompt.Answer, config *project.Config, _ prompt.Results) {
config.Profile = ans.Value
},
})
}

View File

@ -76,7 +76,7 @@ func (q Choice) Ask(res Results) (string, Answer, error) {
type Answers []Answer
type AnswerCallback func(ans Answer, prj *project.Project, res Results)
type AnswerCallback func(ans Answer, config *project.Config, res Results)
type Answer struct {
Value string

View File

@ -15,12 +15,14 @@ import (
var syncCmd = &cobra.Command{
Use: "sync",
Short: "run syncs for the project",
PreRunE: project.Configure,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
wsc := project.Current.WorkspacesClient()
wsc := project.Get(ctx).WorkspacesClient()
if *remotePath == "" {
me, err := project.Current.Me()
me, err := project.Get(ctx).Me()
if err != nil {
return err
}

View File

@ -27,7 +27,7 @@ type watchdog struct {
}
func putFile(ctx context.Context, path string, content io.Reader) error {
wsc := project.Current.WorkspacesClient()
wsc := project.Get(ctx).WorkspacesClient()
// workspace mkdirs is idempotent
err := wsc.Workspace.MkdirsByPath(ctx, filepath.Dir(path))
if err != nil {

2
go.mod
View File

@ -10,7 +10,7 @@ require (
github.com/mitchellh/go-homedir v1.1.0 // MIT
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // BSD-2-Clause
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // MIT
github.com/spf13/cobra v1.4.0 // Apache 2.0
github.com/spf13/cobra v1.5.0 // Apache 2.0
github.com/stretchr/testify v1.8.0 // MIT
github.com/whilp/git-urls v1.0.0 // MIT
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // BSD-3-Clause

6
go.sum
View File

@ -76,7 +76,7 @@ github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -208,8 +208,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI=
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cobra v1.4.0 h1:y+wJpx64xcgO1V+RcnwW0LEHxTKRi2ZDPSBjWnrg88Q=
github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g=
github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU=
github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"github.com/databricks/bricks/folders"
@ -29,7 +30,7 @@ type Assertions struct {
ServicePrincipals []string `json:"service_principals,omitempty"`
}
type Project struct {
type Config struct {
Name string `json:"name"` // or do default from folder name?..
Profile string `json:"profile,omitempty"` // rename?
Isolation Isolation `json:"isolation,omitempty"`
@ -51,20 +52,20 @@ type Project struct {
Assertions *Assertions `json:"assertions,omitempty"`
}
func (p Project) IsDevClusterDefined() bool {
return reflect.ValueOf(p.DevCluster).IsZero()
func (c Config) IsDevClusterDefined() bool {
return reflect.ValueOf(c.DevCluster).IsZero()
}
// IsDevClusterJustReference denotes reference-only clusters.
// This conflicts with Soft isolation. Happens for cost-restricted projects,
// where there's only a single Shared Autoscaling cluster per workspace and
// general users have no ability to create other interactive clusters.
func (p *Project) IsDevClusterJustReference() bool {
if p.DevCluster.ClusterName == "" {
func (c *Config) IsDevClusterJustReference() bool {
if c.DevCluster.ClusterName == "" {
return false
}
return reflect.DeepEqual(p.DevCluster, &clusters.ClusterInfo{
ClusterName: p.DevCluster.ClusterName,
return reflect.DeepEqual(c.DevCluster, &clusters.ClusterInfo{
ClusterName: c.DevCluster.ClusterName,
})
}
@ -75,12 +76,8 @@ func IsDatabricksProject() bool {
return err == nil
}
func loadProjectConf() (prj Project, err error) {
root, err := findProjectRoot()
if err != nil {
return
}
config, err := os.Open(fmt.Sprintf("%s/%s", root, ConfigFile))
func loadProjectConf(root string) (c Config, err error) {
config, err := os.Open(filepath.Join(root, ConfigFile))
if err != nil {
return
}
@ -88,20 +85,20 @@ func loadProjectConf() (prj Project, err error) {
if err != nil {
return
}
err = yaml.Unmarshal(raw, &prj)
err = yaml.Unmarshal(raw, &c)
if err != nil {
return
}
return validateAndApplyProjectDefaults(prj)
return validateAndApplyProjectDefaults(c)
}
func validateAndApplyProjectDefaults(prj Project) (Project, error) {
func validateAndApplyProjectDefaults(c Config) (Config, error) {
// defaultCluster := clusters.ClusterInfo{
// NodeTypeID: "smallest",
// SparkVersion: "latest",
// AutoterminationMinutes: 30,
// }
return prj, nil
return c, nil
}
func findProjectRoot() (string, error) {

View File

@ -1,39 +1,13 @@
package project
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFindProjectRoot(t *testing.T) {
wd, _ := os.Getwd()
defer os.Chdir(wd)
err := os.Chdir("testdata/a/b/c")
assert.NoError(t, err)
root, err := findProjectRoot()
assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("%s/testdata", wd), root)
}
func TestFindProjectRootInRoot(t *testing.T) {
wd, _ := os.Getwd()
defer os.Chdir(wd)
err := os.Chdir("/tmp")
assert.NoError(t, err)
_, err = findProjectRoot()
assert.EqualError(t, err, "cannot find databricks.yml anywhere")
}
func TestLoadProjectConf(t *testing.T) {
wd, _ := os.Getwd()
defer os.Chdir(wd)
os.Chdir("testdata/a/b/c")
prj, err := loadProjectConf()
prj, err := loadProjectConf("./testdata")
assert.NoError(t, err)
assert.Equal(t, "dev", prj.Name)
assert.True(t, prj.IsDevClusterJustReference())

View File

@ -10,73 +10,96 @@ import (
"github.com/databricks/databricks-sdk-go/service/commands"
"github.com/databricks/databricks-sdk-go/service/scim"
"github.com/databricks/databricks-sdk-go/workspaces"
"github.com/spf13/cobra"
)
// Current CLI application state - figure out
var Current inner
type inner struct {
type project struct {
mu sync.Mutex
once sync.Once
project *Project
config *Config
wsc *workspaces.WorkspacesClient
me *scim.User
}
func (i *inner) init() {
i.mu.Lock()
defer i.mu.Unlock()
i.once.Do(func() {
prj, err := loadProjectConf()
i.wsc = workspaces.New(&databricks.Config{Profile: prj.Profile})
// Configure is used as a PreRunE function for all commands that
// require a project to be configured. If a project could successfully
// be found and loaded, it is set on the command's context object.
func Configure(cmd *cobra.Command, args []string) error {
root, err := getRoot()
if err != nil {
panic(err)
return err
}
ctx, err := Initialize(cmd.Context(), root)
if err != nil {
panic(err)
}
i.project = &prj
})
return err
}
func (i *inner) Project() *Project {
i.init()
return i.project
cmd.SetContext(ctx)
return nil
}
// Make sure to initialize the workspaces client on project init
func (i *inner) WorkspacesClient() *workspaces.WorkspacesClient {
i.init()
return i.wsc
}
// Placeholder to use as unique key in context.Context.
var projectKey int
func (i *inner) Me() (*scim.User, error) {
i.mu.Lock()
defer i.mu.Unlock()
if i.me != nil {
return i.me, nil
}
me, err := i.wsc.CurrentUser.Me(context.Background())
// Initialize loads a project configuration given a root.
// It stores the project on a new context.
// The project is available through the `Get()` function.
func Initialize(ctx context.Context, root string) (context.Context, error) {
config, err := loadProjectConf(root)
if err != nil {
return nil, err
}
i.me = me
p := project{
config: &config,
}
p.wsc = workspaces.New(&databricks.Config{Profile: config.Profile})
return context.WithValue(ctx, &projectKey, &p), nil
}
// Get returns the project as configured on the context.
// It panics if the context does not carry a project, i.e. when
// Configure/Initialize has not run for this command.
func Get(ctx context.Context) *project {
	// Use a non-shadowing local name: this package also declares the
	// type `project`, and reusing it as a variable name obscures both.
	p, ok := ctx.Value(&projectKey).(*project)
	if !ok {
		panic(`context not configured with project`)
	}
	return p
}
// WorkspacesClient returns the workspaces API client held by this
// project. The client is created in Initialize (from the project's
// profile), so it is always non-nil for a project obtained via Get.
func (p *project) WorkspacesClient() *workspaces.WorkspacesClient {
	return p.wsc
}
// Me returns the current SCIM user as reported by the workspace,
// caching the result so the API is called at most once per project.
// The lookup and the cache write both happen under p.mu, so concurrent
// callers are safe.
func (p *project) Me() (*scim.User, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.me != nil {
		// Cached from an earlier call.
		return p.me, nil
	}
	// NOTE(review): uses context.Background() rather than a caller-supplied
	// context, so cancellation does not propagate here — confirm intended.
	me, err := p.wsc.CurrentUser.Me(context.Background())
	if err != nil {
		return nil, err
	}
	p.me = me
	return me, nil
}
func (i *inner) DeploymentIsolationPrefix() string {
if i.project.Isolation == None {
return i.project.Name
func (p *project) DeploymentIsolationPrefix() string {
if p.config.Isolation == None {
return p.config.Name
}
if i.project.Isolation == Soft {
me, err := i.Me()
if p.config.Isolation == Soft {
me, err := p.Me()
if err != nil {
panic(err)
}
return fmt.Sprintf("%s/%s", i.project.Name, me.UserName)
return fmt.Sprintf("%s/%s", p.config.Name, me.UserName)
}
panic(fmt.Errorf("unknow project isolation: %s", i.project.Isolation))
panic(fmt.Errorf("unknow project isolation: %s", p.config.Isolation))
}
func getClusterIdFromClusterName(ctx context.Context,
@ -101,24 +124,24 @@ func getClusterIdFromClusterName(ctx context.Context,
// Old version of getting development cluster details with isolation implemented.
// Kept just for reference. Remove once isolation is implemented properly
/*
func (i *inner) DevelopmentCluster(ctx context.Context) (cluster clusters.ClusterInfo, err error) {
api := clusters.NewClustersAPI(ctx, i.Client()) // TODO: rewrite with normal SDK
if i.project.DevCluster == nil {
i.project.DevCluster = &clusters.Cluster{}
func (p *project) DevelopmentCluster(ctx context.Context) (cluster clusters.ClusterInfo, err error) {
api := clusters.NewClustersAPI(ctx, p.Client()) // TODO: rewrite with normal SDK
if p.project.DevCluster == nil {
p.project.DevCluster = &clusters.Cluster{}
}
dc := i.project.DevCluster
if i.project.Isolation == Soft {
if i.project.IsDevClusterJustReference() {
dc := p.project.DevCluster
if p.project.Isolation == Soft {
if p.project.IsDevClusterJustReference() {
err = fmt.Errorf("projects with soft isolation cannot have named clusters")
return
}
dc.ClusterName = fmt.Sprintf("dev/%s", i.DeploymentIsolationPrefix())
dc.ClusterName = fmt.Sprintf("dev/%s", p.DeploymentIsolationPrefix())
}
if dc.ClusterName == "" {
err = fmt.Errorf("please either pick `isolation: soft` or specify a shared cluster name")
return
}
return api.GetOrCreateRunningCluster(dc.ClusterName, *dc)
return app.GetOrCreateRunningCluster(dc.ClusterName, *dc)
}
func runCommandOnDev(ctx context.Context, language, command string) common.CommandResults {
@ -138,17 +161,16 @@ func RunPythonOnDev(ctx context.Context, command string) common.CommandResults {
}
*/
// TODO: Add safe access to i.project and i.project.DevCluster that throws errors if
// TODO: Add safe access to p.project and p.project.DevCluster that throws errors if
// the fields are not defined properly
func (i *inner) GetDevelopmentClusterId(ctx context.Context) (clusterId string, err error) {
i.init()
clusterId = i.project.DevCluster.ClusterId
clusterName := i.project.DevCluster.ClusterName
func (p *project) GetDevelopmentClusterId(ctx context.Context) (clusterId string, err error) {
clusterId = p.config.DevCluster.ClusterId
clusterName := p.config.DevCluster.ClusterName
if clusterId != "" {
return
} else if clusterName != "" {
// Add workspaces client on init
return getClusterIdFromClusterName(ctx, i.wsc, clusterName)
return getClusterIdFromClusterName(ctx, p.wsc, clusterName)
} else {
// TODO: Add the project config file location used to error message
err = fmt.Errorf("please define either development cluster's cluster_id or cluster_name in your project config")
@ -157,14 +179,14 @@ func (i *inner) GetDevelopmentClusterId(ctx context.Context) (clusterId string,
}
func runCommandOnDev(ctx context.Context, language, command string) commands.CommandResults {
clusterId, err := Current.GetDevelopmentClusterId(ctx)
clusterId, err := Get(ctx).GetDevelopmentClusterId(ctx)
if err != nil {
return commands.CommandResults{
ResultType: "error",
Summary: err.Error(),
}
}
return Current.wsc.Commands.Execute(ctx, clusterId, language, command)
return Get(ctx).wsc.Commands.Execute(ctx, clusterId, language, command)
}
func RunPythonOnDev(ctx context.Context, command string) commands.CommandResults {

15
project/project_test.go Normal file
View File

@ -0,0 +1,15 @@
package project
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProjectInitialize verifies that a project loads from a config
// root and is retrievable again through the returned context.
func TestProjectInitialize(t *testing.T) {
	ctx, err := Initialize(context.Background(), "./testdata")
	require.NoError(t, err)

	// testify's Equal takes (t, expected, actual); expected comes first.
	assert.Equal(t, "dev", Get(ctx).config.Name)
}

38
project/root.go Normal file
View File

@ -0,0 +1,38 @@
package project
import (
"fmt"
"os"
"github.com/databricks/bricks/folders"
)
const bricksRoot = "BRICKS_ROOT"
// getRoot returns the project root.
// If the `BRICKS_ROOT` environment variable is set, we assume its value
// to be a valid project root. Otherwise we try to find it by traversing
// the path and looking for a project configuration file.
// getRoot returns the project root.
// If the `BRICKS_ROOT` environment variable is set, we assume its value
// to be a valid project root. Otherwise we try to find it by traversing
// the path and looking for a project configuration file.
func getRoot() (string, error) {
	path, ok := os.LookupEnv(bricksRoot)
	if ok {
		stat, err := os.Stat(path)
		if err == nil && !stat.IsDir() {
			err = fmt.Errorf("not a directory")
		}
		if err != nil {
			return "", fmt.Errorf(`invalid project root %s="%s": %w`, bricksRoot, path, err)
		}
		return path, nil
	}
	wd, err := os.Getwd()
	if err != nil {
		return "", err
	}
	path, err = folders.FindDirWithLeaf(wd, ConfigFile)
	if err != nil {
		// Wrap the underlying error instead of discarding it so callers
		// can see why the lookup failed.
		return "", fmt.Errorf("unable to locate project root: %w", err)
	}
	return path, nil
}

92
project/root_test.go Normal file
View File

@ -0,0 +1,92 @@
package project
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
// Changes into specified directory for the duration of the test.
// Returns the current working directory.
// chdir switches the working directory to dir for the duration of the
// test and registers a cleanup that restores the original directory.
// It returns the working directory that was active before the switch.
func chdir(t *testing.T, dir string) string {
	previous, err := os.Getwd()
	require.NoError(t, err)

	target, err := filepath.Abs(dir)
	require.NoError(t, err)
	require.NoError(t, os.Chdir(target))

	t.Cleanup(func() {
		require.NoError(t, os.Chdir(previous))
	})
	return previous
}
// TestRootFromEnv verifies that the project root is taken from the
// BRICKS_ROOT environment variable when it points at a directory.
func TestRootFromEnv(t *testing.T) {
	dir := t.TempDir()
	t.Setenv(bricksRoot, dir)

	root, err := getRoot()
	require.NoError(t, err)
	// require.Equal takes (t, expected, actual); expected comes first.
	require.Equal(t, dir, root)
}
// TestRootFromEnvDoesntExist verifies that a BRICKS_ROOT value pointing
// at a non-existent path is rejected.
func TestRootFromEnvDoesntExist(t *testing.T) {
	dir := t.TempDir()
	t.Setenv(bricksRoot, filepath.Join(dir, "doesntexist"))

	// require.Errorf only asserts err != nil (its trailing arguments are
	// a failure message); use ErrorContains to actually check the message.
	_, err := getRoot()
	require.ErrorContains(t, err, "invalid project root")
}
// TestRootFromEnvIsFile verifies that a BRICKS_ROOT value pointing at a
// regular file (not a directory) is rejected.
func TestRootFromEnvIsFile(t *testing.T) {
	dir := t.TempDir()
	f, err := os.Create(filepath.Join(dir, "invalid"))
	require.NoError(t, err)
	// Check the Close error instead of silently ignoring it.
	require.NoError(t, f.Close())
	t.Setenv(bricksRoot, f.Name())

	// require.Errorf only asserts err != nil; ErrorContains checks the message.
	_, err = getRoot()
	require.ErrorContains(t, err, "invalid project root")
}
// TestRootIfEnvIsEmpty verifies that an empty (but set) BRICKS_ROOT is
// rejected: os.Stat("") fails, so the value cannot be a valid root.
func TestRootIfEnvIsEmpty(t *testing.T) {
	t.Setenv(bricksRoot, "")

	// require.Errorf only asserts err != nil; ErrorContains checks the message.
	_, err := getRoot()
	require.ErrorContains(t, err, "invalid project root")
}
// TestRootLookup verifies that the project root is discovered by
// walking up from the working directory when BRICKS_ROOT is unset.
func TestRootLookup(t *testing.T) {
	// Have to set then unset to allow the testing package to revert it
	// to its original value.
	t.Setenv(bricksRoot, "")
	os.Unsetenv(bricksRoot)

	// It should find the project root from $PWD.
	wd := chdir(t, "./testdata/a/b/c")
	root, err := getRoot()
	require.NoError(t, err)
	// require.Equal takes (t, expected, actual); expected comes first.
	require.Equal(t, filepath.Join(wd, "testdata"), root)
}
// TestRootLookupError checks that root discovery fails when started
// from a directory tree without a project configuration file.
func TestRootLookupError(t *testing.T) {
	// Set then unset so the testing package restores the original value.
	t.Setenv(bricksRoot, "")
	os.Unsetenv(bricksRoot)

	// A fresh temporary directory holds no project configuration file.
	chdir(t, t.TempDir())

	_, err := getRoot()
	require.ErrorContains(t, err, "unable to locate project root")
}

View File

@ -63,7 +63,7 @@ func UploadWheelToDBFSWithPEP503(ctx context.Context, dir string) (string, error
// extra index URLs. See more pointers at https://stackoverflow.com/q/30889494/277035
dbfsLoc := fmt.Sprintf("%s/%s/%s", DBFSWheelLocation, dist.NormalizedName(), path.Base(wheel))
wsc := project.Current.WorkspacesClient()
wsc := project.Get(ctx).WorkspacesClient()
wf, err := os.Open(wheel)
if err != nil {
return "", err