add tests

This commit is contained in:
Andrew Nester 2023-08-16 16:34:46 +02:00
parent 2e3c5ca9ea
commit a9ed5df9fe
No known key found for this signature in database
GPG Key ID: 12BC628A44B7DA57
6 changed files with 71 additions and 56 deletions

View File

@ -17,7 +17,6 @@ func getSync(ctx context.Context, b *bundle.Bundle) (*sync.Sync, error) {
opts := sync.SyncOptions{ opts := sync.SyncOptions{
LocalPath: b.Config.Path, LocalPath: b.Config.Path,
RemotePath: b.Config.Workspace.FilesPath, RemotePath: b.Config.Workspace.FilesPath,
Full: false, Full: false,
CurrentUser: b.Config.Workspace.CurrentUser.User, CurrentUser: b.Config.Workspace.CurrentUser.User,

View File

@ -114,9 +114,14 @@ func generateNotebookWrapper(b *bundle.Bundle, task *jobs.PythonWheelTask, libra
} }
defer f.Close() defer f.Close()
params, err := generateParameters(task)
if err != nil {
return "", err
}
data := map[string]any{ data := map[string]any{
"Libraries": libraries, "Libraries": libraries,
"Params": generateParameters(task), "Params": params,
"Task": task, "Task": task,
} }
@ -127,7 +132,10 @@ func generateNotebookWrapper(b *bundle.Bundle, task *jobs.PythonWheelTask, libra
return notebookName, t.Execute(f, data) return notebookName, t.Execute(f, data)
} }
func generateParameters(task *jobs.PythonWheelTask) string { func generateParameters(task *jobs.PythonWheelTask) (string, error) {
if task.Parameters != nil && task.NamedParameters != nil {
return "", fmt.Errorf("not allowed to pass both paramaters and named_parameters")
}
params := append([]string{"python"}, task.Parameters...) params := append([]string{"python"}, task.Parameters...)
for k, v := range task.NamedParameters { for k, v := range task.NamedParameters {
params = append(params, fmt.Sprintf("%s=%s", k, v)) params = append(params, fmt.Sprintf("%s=%s", k, v))
@ -135,5 +143,5 @@ func generateParameters(task *jobs.PythonWheelTask) string {
for i := range params { for i := range params {
params[i] = `"` + params[i] + `"` params[i] = `"` + params[i] + `"`
} }
return strings.Join(params, ", ") return strings.Join(params, ", "), nil
} }

View File

@ -0,0 +1,55 @@
package python
import (
"testing"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)
type testCase struct {
Actual []string
Expected string
}
type NamedParams map[string]string
type testCaseNamed struct {
Actual NamedParams
Expected string
}
var paramsTestCases []testCase = []testCase{
{[]string{}, `"python"`},
{[]string{"a"}, `"python", "a"`},
{[]string{"a", "b"}, `"python", "a", "b"`},
{[]string{"123!@#$%^&*()-="}, `"python", "123!@#$%^&*()-="`},
}
var paramsTestCasesNamed []testCaseNamed = []testCaseNamed{
{NamedParams{}, `"python"`},
{NamedParams{"a": "1"}, `"python", "a=1"`},
{NamedParams{"a": "1", "b": "2"}, `"python", "a=1", "b=2"`},
}
func TestGenerateParameters(t *testing.T) {
for _, c := range paramsTestCases {
task := &jobs.PythonWheelTask{Parameters: c.Actual}
result, err := generateParameters(task)
require.NoError(t, err)
require.Equal(t, c.Expected, result)
}
}
func TestGenerateNamedParameters(t *testing.T) {
for _, c := range paramsTestCasesNamed {
task := &jobs.PythonWheelTask{NamedParameters: c.Actual}
result, err := generateParameters(task)
require.NoError(t, err)
require.Equal(t, c.Expected, result)
}
}
func TestGenerateBoth(t *testing.T) {
task := &jobs.PythonWheelTask{NamedParameters: map[string]string{"a": "1"}, Parameters: []string{"b"}}
_, err := generateParameters(task)
require.Error(t, err)
}

View File

@ -38,7 +38,6 @@ func (f *syncFlags) syncOptionsFromBundle(cmd *cobra.Command, args []string, b *
opts := sync.SyncOptions{ opts := sync.SyncOptions{
LocalPath: b.Config.Path, LocalPath: b.Config.Path,
RemotePath: b.Config.Workspace.FilesPath, RemotePath: b.Config.Workspace.FilesPath,
Full: f.full, Full: f.full,
PollInterval: f.interval, PollInterval: f.interval,

View File

@ -1,45 +0,0 @@
package fileset
import (
"io/fs"
"os"
"path/filepath"
)
type GlobSet struct {
root string
patterns []string
}
func NewGlobSet(root string, includes []string) *GlobSet {
return &GlobSet{root, includes}
}
// Return all tracked files for Repo
func (s *GlobSet) All() ([]File, error) {
files := make([]File, 0)
for _, pattern := range s.patterns {
matches, err := filepath.Glob(pattern)
if err != nil {
return files, err
}
for _, match := range matches {
if !filepath.IsAbs(match) {
match = filepath.Join(s.root, match)
}
matchRel, err := filepath.Rel(s.root, match)
if err != nil {
return files, err
}
stat, err := os.Stat(match)
if err != nil {
return files, err
}
files = append(files, File{fs.FileInfoToDirEntry(stat), match, matchRel})
}
}
return files, nil
}

View File

@ -33,7 +33,6 @@ type Sync struct {
*SyncOptions *SyncOptions
fileSet *git.FileSet fileSet *git.FileSet
snapshot *Snapshot snapshot *Snapshot
filer filer.Filer filer filer.Filer