Compare commits

...

41 Commits

Author SHA1 Message Date
Richard Nordström facbd27774
Merge ca08796f77 into b323703c1b 2024-11-23 18:01:01 +01:00
shreyas-goenka b323703c1b
Add validation for single node clusters (#1909)
## Changes
This PR adds a warning validating that the configuration for a single
node cluster is valid for interactive, job, job-task, and pipeline
clusters.

Note: We skip the validation if a cluster policy is configured because
the policy is likely to configure `spark_conf` / `custom_tags` itself.

Note: Terrform originally only had validation for interactive, job, and
job-task clusters. This PR adding the validation for pipeline clusters
as well is new.

This PR follows the same logic as we used to have in Terraform. The
validation was removed from Terraform because we had no way to demote
the error to a warning:
https://github.com/databricks/terraform-provider-databricks/pull/4222

### Background
Single-node clusters require `spark_conf` and `custom_tags` to be
correctly set in the cluster definition for them to function optimally.
The cluster will be created even if incorrectly configured, but its
performance will not be great.

For example, if both `spark_conf` and `custom_tags` are not set and
`num_workers` is 0, then only the driver process will be launched on the
cluster compute instance thus leading to sub-optimal utilization of
available compute resources and no parallelization across worker
processes when processing a spark query.

### Issue

This PR addresses some issues reported in
https://github.com/databricks/cli/issues/1546

## Tests
Unit tests and manually.

Example output of the warning:
```
➜  bundle-playground git:(master) ✗ cli bundle validate
Warning: Single node cluster is not correctly configured
  at resources.pipelines.bar.clusters[0]
  in databricks.yml:29:11

num_workers should be 0 only for single-node clusters. To create a
valid single node cluster please ensure that the following properties
are correctly set in the cluster specification:

  spark_conf:
    spark.databricks.cluster.profile: singleNode
    spark.master: local[*]

  custom_tags:
    ResourceClass: SingleNode
  

Name: foobar
Target: default
Workspace:
  User: shreyas.goenka@databricks.com
  Path: /Workspace/Users/shreyas.goenka@databricks.com/.bundle/foobar/default

Found 1 warning
```
2024-11-22 15:48:09 +00:00
Ilya Kuznetsov 490dd058aa
Extended message for warning when source-linked mode is used outside of the workspace (#1929)
## Changes

Added path and locations to the warning which displayed when
source-linked mode is used outside of the workspace
2024-11-22 14:44:33 +00:00
Pieter Noordhuis abfd1713e0
Skip sync warning if no sync paths are defined (#1926)
## Changes

Users can configure the bundle to not synchronize any files with:
```yaml
sync:
  paths: []
```

If it is explicitly configured as an empty list, the validate command
must not warn about not having any files to synchronize. The warning
exists to alert users who are unintentionally not synchronizing any
files (they might have a `.gitignore` pattern that matches everything).

Closes #1663.

## Tests

* New unit test.
2024-11-21 15:03:13 +00:00
Pieter Noordhuis a3cea07c9e
Support lookup by name of notification destinations (#1922)
## Changes

Add support for notification destinations in variable lookups.

More information:
https://docs.databricks.com/en/admin/workspace-settings/notification-destinations.html

Depends on #1921.

## Tests

* New unit test
* Manually confirmed that the lookup works
2024-11-21 15:52:14 +01:00
shreyas-goenka abc2f3c825
Fix `TestAccBundleInitOnMlopsStacks` (#1924)
## Changes
The ML production team modified mlops-stack to use `mode: development`
for their development target here:
https://github.com/databricks/mlops-stacks/pull/174

This PR makes the integration test assertion agnostic of the prefix to
make it pass again.

## Tests
The test passes now
2024-11-21 10:46:24 +00:00
shreyas-goenka c2e2abcc35
Extend "notebook not found" error to warn about missing extension (#1920)
## Changes
The full workspace path for a notebook does not contain the notebook's
extension. If a user converts that file path to a relative path (like
`/Workspace/bundle_root/bar/nb` -> `./bar/nb`), they can be confused as
to why the new file path does not work.

The changes in this PR nudge them to add the appropriate file extension
(e.g., `./bar/nb.py` or `./bar/nb.ipynb`).

One common way users can end up in this scenario is by using the view
job as YAML functionality in the Databricks UI.

## Tests
Unit test and manually.

```
(.venv) ➜  bundle-playground git:(master) ✗ cli bundle validate 
Error: notebook ./foo not found. Local notebook references are expected
to contain one of the following file extensions: [.py, .r, .scala, .sql, .ipynb]
```
2024-11-21 16:21:21 +05:30
Pieter Noordhuis 14fe03dcb9
Breakout variable lookup into separate files and tests (#1921)
## Changes

While looking into adding variable lookups for notification destinations
([API][API]), I found the codegen approach for different classes of
variable lookups a bit complex. The template had a custom field override
(for service principals), the package had an override for the cluster
lookup, and it didn't produce tests.

The notification destinations API uses a default page size of 20 for
listing. I want to use a larger page size to limit the number of API
calls, so that would imply another customization on the template or a
manual override.

This code being rather mechanical, I used copilot to produce all
instances of the resolvers and their tests (after writing one of them
manually).

[api]: https://docs.databricks.com/api/workspace/notificationdestinations

## Tests

* Unit tests pass
* Manual confirmation that lookups of warehouses still work
2024-11-21 11:28:50 +01:00
shreyas-goenka 984c38e03e
Add unique ID to `root_path` for bundle integration test fixtures (#1917)
## Changes
Integration tests using these fixtures could have been flaky when run in
parallel using the same user's identity. They would also possibly have
piggybacked state from previous runs.

This PR adds a UUID to the root_path to force independent bundle
deployments for every test run.

I have checked that all bundles in `internal/bundle/bundles` have
`root_path` namespaced to a UUID.

## Tests
Self testing.
2024-11-20 16:30:10 +00:00
Pieter Noordhuis ade95d9649
[Release] Release v0.235.0 (#1918)
**Note:** the `bundle generate` command now uses the
`.<resource-type>.yml`
sub-extension for the configuration files it writes. Existing
configuration
files that do not use this sub-extension are renamed to include it.

Bundles:
* Make `TableName` field part of quality monitor schema
([#1903](https://github.com/databricks/cli/pull/1903)).
* Do not prepend paths starting with ~ or variable reference
([#1905](https://github.com/databricks/cli/pull/1905)).
* Fix workspace extensions filer accidentally reading notebooks
([#1891](https://github.com/databricks/cli/pull/1891)).
* Fix template initialization when running on Databricks
([#1912](https://github.com/databricks/cli/pull/1912)).
* Source-linked deployments for bundles in the workspace
([#1884](https://github.com/databricks/cli/pull/1884)).
* Added integration test to deploy bundle to /Shared root path
([#1914](https://github.com/databricks/cli/pull/1914)).
* Update filenames used by bundle generate to use `.<resource-type>.yml`
([#1901](https://github.com/databricks/cli/pull/1901)).

Internal:
* Extract functionality to detect if the CLI is running on DBR
([#1889](https://github.com/databricks/cli/pull/1889)).
* Consolidate test helpers for `io/fs`
([#1906](https://github.com/databricks/cli/pull/1906)).
* Use `fs.FS` interface to read template
([#1910](https://github.com/databricks/cli/pull/1910)).
* Use `filer.Filer` to write template instantiation
([#1911](https://github.com/databricks/cli/pull/1911)).
2024-11-20 14:48:18 +00:00
Andrew Nester 592e1111b7
Update filenames used by bundle generate to use `.<resource-type>.yml` (#1901)
## Changes
Update filenames used by bundle generate to use '.resource-type.yml'

Similar to [Add sub-extension to resource files in built-in templates by
shreyas-goenka · Pull Request #1777 ·
databricks/cli](https://github.com/databricks/cli/pull/1777)

---------

Co-authored-by: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com>
2024-11-20 13:53:25 +01:00
Andrew Nester fab3e8f168
Added integration test to deploy bundle to /Shared root path (#1914)
## Changes
Added integration test to deploy bundle to /Shared root path

## Tests
```
--- PASS: TestAccDeployBasicToSharedWorkspace (24.58s)
PASS
coverage: 31.2% of statements in ./...
ok  	github.com/databricks/cli/internal/bundle	25.572s	coverage: 31.2% of statements in ./...
```

---------

Co-authored-by: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com>
2024-11-20 12:20:39 +00:00
Ilya Kuznetsov 756e55fabc
Source-linked deployments for bundles in the workspace (#1884)
## Changes

This change adds a preset for source-linked deployments. It is enabled
by default for targets in `development` mode **if** the Databricks CLI
is running from the `/Workspace` directory on DBR. It does not have an
effect when running the CLI anywhere else.

Key highlights:
1. Files in this mode won't be uploaded to workspace
2. Created resources will use references to source files instead of
their workspace copies

## Tests
1. Apply preset unit test covering conditional logic
2. High-level process target mode unit test for testing integration
between mutators

---------

Co-authored-by: Pieter Noordhuis <pieter.noordhuis@databricks.com>
2024-11-20 13:22:27 +01:00
Pieter Noordhuis 886e14910c
Fix template initialization when running on Databricks (#1912)
## Changes

When running the CLI on Databricks Runtime (DBR), use the
extension-aware filer to write an instantiated template if the instance
path is located in the workspace filesystem.

Notebooks cannot be written through the workspace filesystem's FUSE
mount. As a result, this is the only method for initializing templates
that contain notebooks when running the CLI on DBR and writing to the
workspace filesystem.

Depends on #1910 and #1911.

Supersedes #1744.

## Tests

* Manually confirmed I can initialize a template with notebooks when
running the CLI from the web terminal.
2024-11-20 11:42:23 +00:00
Pieter Noordhuis 75b09ff230
Use `filer.Filer` to write template instantiation (#1911)
## Changes

Prior to this change, the output directory was part of the `renderer`
type and passed down to every `file` it produced. Every file knew its
absolute destination path. This is incompatible with the use of a filer,
where all operations are automatically anchored to some base path.

To make this compatible, this change updates:
* the `file` type to only know its own path relative to the instantiation root,
* the `renderer` type to no longer require or pass along the output directory,
* the `persistToDisk` function to take a context and filer argument,
* the `filer.WriteMode` to represent permission bits

## Tests

* Existing tests pass.
* Manually confirmed template initialization works as expected.
2024-11-20 11:11:31 +01:00
Pieter Noordhuis 4fea0219fd
Use `fs.FS` interface to read template (#1910)
## Changes

While working on the v2 of #1744, I found that:
* Template initialization first copies built-in templates to a temporary
directory before initializing them
* Reading a template's contents goes through a `filer.Filer` but is
hardcoded to a local one

This change updates the interface for reading templates to be `fs.FS`.
This is compatible with the `embed.FS` type for the built-in templates,
so they no longer have to be copied to a temporary directory before
being used.

The alternative is to use a `filer.Filer` throughout, but this would
have required even more plumbing, and we don't need to _read_ templates,
including notebooks, from the workspace filesystem (yet?).

As part of making `template.Materialize` take an `fs.FS` argument, the
logic to match a given argument to a particular built-in template in the
`init` command has moved to sit next to its implementation.

## Tests

Existing tests pass.
2024-11-20 09:28:35 +00:00
shreyas-goenka 72dde793d8
Fix workspace extensions filer accidentally reading notebooks (#1891)
## Changes
The workspace extensions filer should not read or stat a notebook called
`foo` if the user calls `.Stat(ctx, "foo")`.

Instead, the filer should return a file not found error. This is because
the contract for the workspace extensions filer is to only work for
notebooks when the file path / name includes the extension (example:
`foo.ipynb` or `foo.sql` instead of just `foo`)

## Tests
Integration tests.
2024-11-18 17:25:24 +00:00
Richard Nordström ca08796f77
Merge branch 'main' into feature/logout 2024-11-06 16:50:52 +01:00
Richard Nordström fc23aa584d
Merge branch 'main' into feature/logout 2024-10-22 21:10:07 +02:00
Richard Nordström 6af6b55832
Merge branch 'main' into feature/logout 2024-10-16 01:31:40 +02:00
Richard Nordström 865964e029
reduce scope for logout cmd to only remove the OAuth token 2024-10-06 23:44:11 +02:00
Richard Nordström 41999fbe87
Merge branch 'main' into feature/logout 2024-10-06 22:44:06 +02:00
Richard Nordström d2bead3fe6
Merge branch 'main' into feature/logout 2024-10-01 22:01:40 +02:00
Richard Nordström 11c37673a6
make tokenCacheMock consistent naming with struct 2024-09-23 21:55:31 +02:00
Richard Nordström 18d3fea34e
Merge branch 'main' into feature/logout 2024-09-23 21:43:07 +02:00
Richard Nordström b7ff019b60
add test for file write 2024-09-23 21:20:56 +02:00
Richard Nordström bb35ca090f
logoutSession not exportable 2024-09-23 20:37:53 +02:00
Richard Nordström d037ec32a1
add new write function to persist to disk 2024-09-23 20:26:43 +02:00
Richard Nordström 89d3b1a4df
remove redundant version specification 2024-09-23 20:23:09 +02:00
Richard Nordström 37067ef933
rename DeleteKey to Delete 2024-09-23 20:21:38 +02:00
Richard Nordström 171c3fdd75
Merge branch 'main' into feature/logout 2024-09-19 21:02:09 +02:00
Richard Nordström dc44dbd667
Merge branch 'main' into feature/logout 2024-09-07 00:45:28 +02:00
Richard Nordström b044a6c0e0
Merge branch 'main' into feature/logout 2024-09-04 14:00:23 +02:00
Richard Nordström 7636c55ba9
Merge branch 'main' into feature/logout 2024-09-03 21:02:45 +02:00
Richard Nordström e88fd0a5c0
Merge branch 'main' into feature/logout 2024-09-02 17:26:33 +02:00
Richard Nordström 6c32a0df7a
improve profile handling and add tests 2024-09-02 00:15:48 +02:00
Richard Nordström 7eca34a7b2
fix typo 2024-09-01 20:25:40 +02:00
Richard Nordström 6277cf24c6
allow no OAuth in case PAT is used 2024-09-01 20:20:22 +02:00
Richard Nordström 6a8b2f452f
use PersistentAuth struc 2024-09-01 20:20:21 +02:00
Richard Nordström 712e2919f5
add logout cmd 2024-09-01 20:20:21 +02:00
Richard Nordström 882ccba0f5
add DeleteKey to TokenCache for logout cmd 2024-09-01 20:20:21 +02:00
85 changed files with 3353 additions and 809 deletions

View File

@ -5,8 +5,7 @@
}, },
"batch": { "batch": {
".codegen/cmds-workspace.go.tmpl": "cmd/workspace/cmd.go", ".codegen/cmds-workspace.go.tmpl": "cmd/workspace/cmd.go",
".codegen/cmds-account.go.tmpl": "cmd/account/cmd.go", ".codegen/cmds-account.go.tmpl": "cmd/account/cmd.go"
".codegen/lookup.go.tmpl": "bundle/config/variable/lookup.go"
}, },
"toolchain": { "toolchain": {
"required": ["go"], "required": ["go"],

View File

@ -1,134 +0,0 @@
// Code generated from OpenAPI specs by Databricks SDK Generator. DO NOT EDIT.
package variable
{{ $allowlist :=
list
"alerts"
"clusters"
"cluster-policies"
"clusters"
"dashboards"
"instance-pools"
"jobs"
"metastores"
"pipelines"
"service-principals"
"queries"
"warehouses"
}}
{{ $customField :=
dict
"service-principals" "ApplicationId"
}}
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type Lookup struct {
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
{{.Singular.PascalName}} string `json:"{{.Singular.SnakeName}},omitempty"`
{{end}}
{{- end}}
}
func LookupFromMap(m map[string]any) *Lookup {
l := &Lookup{}
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
if v, ok := m["{{.Singular.SnakeName}}"]; ok {
l.{{.Singular.PascalName}} = v.(string)
}
{{end -}}
{{- end}}
return l
}
func (l *Lookup) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
if err := l.validate(); err != nil {
return "", err
}
r := allResolvers()
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
if l.{{.Singular.PascalName}} != "" {
return r.{{.Singular.PascalName}}(ctx, w, l.{{.Singular.PascalName}})
}
{{end -}}
{{- end}}
return "", fmt.Errorf("no valid lookup fields provided")
}
func (l *Lookup) String() string {
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
if l.{{.Singular.PascalName}} != "" {
return fmt.Sprintf("{{.Singular.KebabName}}: %s", l.{{.Singular.PascalName}})
}
{{end -}}
{{- end}}
return ""
}
func (l *Lookup) validate() error {
// Validate that only one field is set
count := 0
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
if l.{{.Singular.PascalName}} != "" {
count++
}
{{end -}}
{{- end}}
if count != 1 {
return fmt.Errorf("exactly one lookup field must be provided")
}
if strings.Contains(l.String(), "${var") {
return fmt.Errorf("lookup fields cannot contain variable references")
}
return nil
}
type resolverFunc func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error)
type resolvers struct {
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
{{.Singular.PascalName}} resolverFunc
{{end -}}
{{- end}}
}
func allResolvers() *resolvers {
r := &resolvers{}
{{range .Services -}}
{{- if in $allowlist .KebabName -}}
r.{{.Singular.PascalName}} = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["{{.Singular.PascalName}}"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.{{.PascalName}}.GetBy{{range .NamedIdMap.NamePath}}{{.PascalName}}{{end}}(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.{{ getOrDefault $customField .KebabName ((index .NamedIdMap.IdPath 0).PascalName) }}), nil
}
{{end -}}
{{- end}}
return r
}

1
.gitattributes vendored
View File

@ -1,4 +1,3 @@
bundle/config/variable/lookup.go linguist-generated=true
cmd/account/access-control/access-control.go linguist-generated=true cmd/account/access-control/access-control.go linguist-generated=true
cmd/account/billable-usage/billable-usage.go linguist-generated=true cmd/account/billable-usage/billable-usage.go linguist-generated=true
cmd/account/budgets/budgets.go linguist-generated=true cmd/account/budgets/budgets.go linguist-generated=true

View File

@ -1,5 +1,28 @@
# Version changelog # Version changelog
## [Release] Release v0.235.0
**Note:** the `bundle generate` command now uses the `.<resource-type>.yml`
sub-extension for the configuration files it writes. Existing configuration
files that do not use this sub-extension are renamed to include it.
Bundles:
* Make `TableName` field part of quality monitor schema ([#1903](https://github.com/databricks/cli/pull/1903)).
* Do not prepend paths starting with ~ or variable reference ([#1905](https://github.com/databricks/cli/pull/1905)).
* Fix workspace extensions filer accidentally reading notebooks ([#1891](https://github.com/databricks/cli/pull/1891)).
* Fix template initialization when running on Databricks ([#1912](https://github.com/databricks/cli/pull/1912)).
* Source-linked deployments for bundles in the workspace ([#1884](https://github.com/databricks/cli/pull/1884)).
* Added integration test to deploy bundle to /Shared root path ([#1914](https://github.com/databricks/cli/pull/1914)).
* Update filenames used by bundle generate to use `.<resource-type>.yml` ([#1901](https://github.com/databricks/cli/pull/1901)).
Internal:
* Extract functionality to detect if the CLI is running on DBR ([#1889](https://github.com/databricks/cli/pull/1889)).
* Consolidate test helpers for `io/fs` ([#1906](https://github.com/databricks/cli/pull/1906)).
* Use `fs.FS` interface to read template ([#1910](https://github.com/databricks/cli/pull/1910)).
* Use `filer.Filer` to write template instantiation ([#1911](https://github.com/databricks/cli/pull/1911)).
## [Release] Release v0.234.0 ## [Release] Release v0.234.0
Bundles: Bundles:

View File

@ -9,6 +9,7 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/textutil" "github.com/databricks/cli/libs/textutil"
@ -221,6 +222,27 @@ func (m *applyPresets) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnos
dashboard.DisplayName = prefix + dashboard.DisplayName dashboard.DisplayName = prefix + dashboard.DisplayName
} }
if config.IsExplicitlyEnabled((b.Config.Presets.SourceLinkedDeployment)) {
isDatabricksWorkspace := dbr.RunsOnRuntime(ctx) && strings.HasPrefix(b.SyncRootPath, "/Workspace/")
if !isDatabricksWorkspace {
target := b.Config.Bundle.Target
path := dyn.NewPath(dyn.Key("targets"), dyn.Key(target), dyn.Key("presets"), dyn.Key("source_linked_deployment"))
diags = diags.Append(
diag.Diagnostic{
Severity: diag.Warning,
Summary: "source-linked deployment is available only in the Databricks Workspace",
Paths: []dyn.Path{
path,
},
Locations: b.Config.GetLocations(path[2:].String()),
},
)
disabled := false
b.Config.Presets.SourceLinkedDeployment = &disabled
}
}
return diags return diags
} }

View File

@ -2,12 +2,16 @@ package mutator_test
import ( import (
"context" "context"
"runtime"
"testing" "testing"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/mutator" "github.com/databricks/cli/bundle/config/mutator"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/bundle/internal/bundletest"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -364,3 +368,88 @@ func TestApplyPresetsResourceNotDefined(t *testing.T) {
}) })
} }
} }
func TestApplyPresetsSourceLinkedDeployment(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("this test is not applicable on Windows because source-linked mode works only in the Databricks Workspace")
}
testContext := context.Background()
enabled := true
disabled := false
workspacePath := "/Workspace/user.name@company.com"
tests := []struct {
bundlePath string
ctx context.Context
name string
initialValue *bool
expectedValue *bool
expectedWarning string
}{
{
name: "preset enabled, bundle in Workspace, databricks runtime",
bundlePath: workspacePath,
ctx: dbr.MockRuntime(testContext, true),
initialValue: &enabled,
expectedValue: &enabled,
},
{
name: "preset enabled, bundle not in Workspace, databricks runtime",
bundlePath: "/Users/user.name@company.com",
ctx: dbr.MockRuntime(testContext, true),
initialValue: &enabled,
expectedValue: &disabled,
expectedWarning: "source-linked deployment is available only in the Databricks Workspace",
},
{
name: "preset enabled, bundle in Workspace, not databricks runtime",
bundlePath: workspacePath,
ctx: dbr.MockRuntime(testContext, false),
initialValue: &enabled,
expectedValue: &disabled,
expectedWarning: "source-linked deployment is available only in the Databricks Workspace",
},
{
name: "preset disabled, bundle in Workspace, databricks runtime",
bundlePath: workspacePath,
ctx: dbr.MockRuntime(testContext, true),
initialValue: &disabled,
expectedValue: &disabled,
},
{
name: "preset nil, bundle in Workspace, databricks runtime",
bundlePath: workspacePath,
ctx: dbr.MockRuntime(testContext, true),
initialValue: nil,
expectedValue: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := &bundle.Bundle{
SyncRootPath: tt.bundlePath,
Config: config.Root{
Presets: config.Presets{
SourceLinkedDeployment: tt.initialValue,
},
},
}
bundletest.SetLocation(b, "presets.source_linked_deployment", []dyn.Location{{File: "databricks.yml"}})
diags := bundle.Apply(tt.ctx, b, mutator.ApplyPresets())
if diags.HasError() {
t.Fatalf("unexpected error: %v", diags)
}
if tt.expectedWarning != "" {
require.Equal(t, tt.expectedWarning, diags[0].Summary)
require.NotEmpty(t, diags[0].Locations)
}
require.Equal(t, tt.expectedValue, b.Config.Presets.SourceLinkedDeployment)
})
}
}

View File

@ -6,6 +6,7 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/iamutil" "github.com/databricks/cli/libs/iamutil"
@ -57,6 +58,14 @@ func transformDevelopmentMode(ctx context.Context, b *bundle.Bundle) {
t.TriggerPauseStatus = config.Paused t.TriggerPauseStatus = config.Paused
} }
if !config.IsExplicitlyDisabled(t.SourceLinkedDeployment) {
isInWorkspace := strings.HasPrefix(b.SyncRootPath, "/Workspace/")
if isInWorkspace && dbr.RunsOnRuntime(ctx) {
enabled := true
t.SourceLinkedDeployment = &enabled
}
}
if !config.IsExplicitlyDisabled(t.PipelinesDevelopment) { if !config.IsExplicitlyDisabled(t.PipelinesDevelopment) {
enabled := true enabled := true
t.PipelinesDevelopment = &enabled t.PipelinesDevelopment = &enabled

View File

@ -3,14 +3,17 @@ package mutator
import ( import (
"context" "context"
"reflect" "reflect"
"runtime"
"strings" "strings"
"testing" "testing"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources" "github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/tags" "github.com/databricks/cli/libs/tags"
"github.com/databricks/cli/libs/vfs"
sdkconfig "github.com/databricks/databricks-sdk-go/config" sdkconfig "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/compute"
@ -140,6 +143,7 @@ func mockBundle(mode config.Mode) *bundle.Bundle {
}, },
}, },
}, },
SyncRoot: vfs.MustNew("/Users/lennart.kats@databricks.com"),
// Use AWS implementation for testing. // Use AWS implementation for testing.
Tagging: tags.ForCloud(&sdkconfig.Config{ Tagging: tags.ForCloud(&sdkconfig.Config{
Host: "https://company.cloud.databricks.com", Host: "https://company.cloud.databricks.com",
@ -522,3 +526,32 @@ func TestPipelinesDevelopmentDisabled(t *testing.T) {
assert.False(t, b.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development) assert.False(t, b.Config.Resources.Pipelines["pipeline1"].PipelineSpec.Development)
} }
func TestSourceLinkedDeploymentEnabled(t *testing.T) {
b, diags := processSourceLinkedBundle(t, true)
require.NoError(t, diags.Error())
assert.True(t, *b.Config.Presets.SourceLinkedDeployment)
}
func TestSourceLinkedDeploymentDisabled(t *testing.T) {
b, diags := processSourceLinkedBundle(t, false)
require.NoError(t, diags.Error())
assert.False(t, *b.Config.Presets.SourceLinkedDeployment)
}
func processSourceLinkedBundle(t *testing.T, presetEnabled bool) (*bundle.Bundle, diag.Diagnostics) {
if runtime.GOOS == "windows" {
t.Skip("this test is not applicable on Windows because source-linked mode works only in the Databricks Workspace")
}
b := mockBundle(config.Development)
workspacePath := "/Workspace/lennart@company.com/"
b.SyncRootPath = workspacePath
b.Config.Presets.SourceLinkedDeployment = &presetEnabled
ctx := dbr.MockRuntime(context.Background(), true)
m := bundle.Seq(ProcessTargetMode(), ApplyPresets())
diags := bundle.Apply(ctx, b, m)
return b, diags
}

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn" "github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/notebook" "github.com/databricks/cli/libs/notebook"
@ -103,8 +104,13 @@ func (t *translateContext) rewritePath(
return fmt.Errorf("path %s is not contained in sync root path", localPath) return fmt.Errorf("path %s is not contained in sync root path", localPath)
} }
// Prefix remote path with its remote root path. var workspacePath string
remotePath := path.Join(t.b.Config.Workspace.FilePath, filepath.ToSlash(localRelPath)) if config.IsExplicitlyEnabled(t.b.Config.Presets.SourceLinkedDeployment) {
workspacePath = t.b.SyncRootPath
} else {
workspacePath = t.b.Config.Workspace.FilePath
}
remotePath := path.Join(workspacePath, filepath.ToSlash(localRelPath))
// Convert local path into workspace path via specified function. // Convert local path into workspace path via specified function.
interp, err := fn(*p, localPath, localRelPath, remotePath) interp, err := fn(*p, localPath, localRelPath, remotePath)
@ -120,7 +126,33 @@ func (t *translateContext) rewritePath(
func (t *translateContext) translateNotebookPath(literal, localFullPath, localRelPath, remotePath string) (string, error) { func (t *translateContext) translateNotebookPath(literal, localFullPath, localRelPath, remotePath string) (string, error) {
nb, _, err := notebook.DetectWithFS(t.b.SyncRoot, filepath.ToSlash(localRelPath)) nb, _, err := notebook.DetectWithFS(t.b.SyncRoot, filepath.ToSlash(localRelPath))
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
return "", fmt.Errorf("notebook %s not found", literal) if filepath.Ext(localFullPath) != notebook.ExtensionNone {
return "", fmt.Errorf("notebook %s not found", literal)
}
extensions := []string{
notebook.ExtensionPython,
notebook.ExtensionR,
notebook.ExtensionScala,
notebook.ExtensionSql,
notebook.ExtensionJupyter,
}
// Check whether a file with a notebook extension already exists. This
// way we can provide a more targeted error message.
for _, ext := range extensions {
literalWithExt := literal + ext
localRelPathWithExt := filepath.ToSlash(localRelPath + ext)
if _, err := fs.Stat(t.b.SyncRoot, localRelPathWithExt); err == nil {
return "", fmt.Errorf(`notebook %s not found. Did you mean %s?
Local notebook references are expected to contain one of the following
file extensions: [%s]`, literal, literalWithExt, strings.Join(extensions, ", "))
}
}
// Return a generic error message if no matching possible file is found.
return "", fmt.Errorf(`notebook %s not found. Local notebook references are expected
to contain one of the following file extensions: [%s]`, literal, strings.Join(extensions, ", "))
} }
if err != nil { if err != nil {
return "", fmt.Errorf("unable to determine if %s is a notebook: %w", localFullPath, err) return "", fmt.Errorf("unable to determine if %s is a notebook: %w", localFullPath, err)

View File

@ -2,8 +2,10 @@ package mutator_test
import ( import (
"context" "context"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
@ -507,6 +509,59 @@ func TestPipelineNotebookDoesNotExistError(t *testing.T) {
assert.EqualError(t, diags.Error(), "notebook ./doesnt_exist.py not found") assert.EqualError(t, diags.Error(), "notebook ./doesnt_exist.py not found")
} }
func TestPipelineNotebookDoesNotExistErrorWithoutExtension(t *testing.T) {
for _, ext := range []string{
".py",
".r",
".scala",
".sql",
".ipynb",
"",
} {
t.Run("case_"+ext, func(t *testing.T) {
dir := t.TempDir()
if ext != "" {
touchEmptyFile(t, filepath.Join(dir, "foo"+ext))
}
b := &bundle.Bundle{
SyncRootPath: dir,
SyncRoot: vfs.MustNew(dir),
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"pipeline": {
PipelineSpec: &pipelines.PipelineSpec{
Libraries: []pipelines.PipelineLibrary{
{
Notebook: &pipelines.NotebookLibrary{
Path: "./foo",
},
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, ".", []dyn.Location{{File: filepath.Join(dir, "fake.yml")}})
diags := bundle.Apply(context.Background(), b, mutator.TranslatePaths())
if ext == "" {
assert.EqualError(t, diags.Error(), `notebook ./foo not found. Local notebook references are expected
to contain one of the following file extensions: [.py, .r, .scala, .sql, .ipynb]`)
} else {
assert.EqualError(t, diags.Error(), fmt.Sprintf(`notebook ./foo not found. Did you mean ./foo%s?
Local notebook references are expected to contain one of the following
file extensions: [.py, .r, .scala, .sql, .ipynb]`, ext))
}
})
}
}
func TestPipelineFileDoesNotExistError(t *testing.T) { func TestPipelineFileDoesNotExistError(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
@ -787,3 +842,163 @@ func TestTranslatePathWithComplexVariables(t *testing.T) {
b.Config.Resources.Jobs["job"].Tasks[0].Libraries[0].Whl, b.Config.Resources.Jobs["job"].Tasks[0].Libraries[0].Whl,
) )
} }
func TestTranslatePathsWithSourceLinkedDeployment(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("this test is not applicable on Windows because source-linked mode works only in the Databricks Workspace")
}
dir := t.TempDir()
touchNotebookFile(t, filepath.Join(dir, "my_job_notebook.py"))
touchNotebookFile(t, filepath.Join(dir, "my_pipeline_notebook.py"))
touchEmptyFile(t, filepath.Join(dir, "my_python_file.py"))
touchEmptyFile(t, filepath.Join(dir, "dist", "task.jar"))
touchEmptyFile(t, filepath.Join(dir, "requirements.txt"))
enabled := true
b := &bundle.Bundle{
SyncRootPath: dir,
SyncRoot: vfs.MustNew(dir),
Config: config.Root{
Workspace: config.Workspace{
FilePath: "/bundle",
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"job": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
NotebookTask: &jobs.NotebookTask{
NotebookPath: "my_job_notebook.py",
},
Libraries: []compute.Library{
{Whl: "./dist/task.whl"},
},
},
{
NotebookTask: &jobs.NotebookTask{
NotebookPath: "/Users/jane.doe@databricks.com/absolute_remote.py",
},
},
{
NotebookTask: &jobs.NotebookTask{
NotebookPath: "my_job_notebook.py",
},
Libraries: []compute.Library{
{Requirements: "requirements.txt"},
},
},
{
SparkPythonTask: &jobs.SparkPythonTask{
PythonFile: "my_python_file.py",
},
},
{
SparkJarTask: &jobs.SparkJarTask{
MainClassName: "HelloWorld",
},
Libraries: []compute.Library{
{Jar: "./dist/task.jar"},
},
},
{
SparkJarTask: &jobs.SparkJarTask{
MainClassName: "HelloWorldRemote",
},
Libraries: []compute.Library{
{Jar: "dbfs:/bundle/dist/task_remote.jar"},
},
},
},
},
},
},
Pipelines: map[string]*resources.Pipeline{
"pipeline": {
PipelineSpec: &pipelines.PipelineSpec{
Libraries: []pipelines.PipelineLibrary{
{
Notebook: &pipelines.NotebookLibrary{
Path: "my_pipeline_notebook.py",
},
},
{
Notebook: &pipelines.NotebookLibrary{
Path: "/Users/jane.doe@databricks.com/absolute_remote.py",
},
},
{
File: &pipelines.FileLibrary{
Path: "my_python_file.py",
},
},
},
},
},
},
},
Presets: config.Presets{
SourceLinkedDeployment: &enabled,
},
},
}
bundletest.SetLocation(b, ".", []dyn.Location{{File: filepath.Join(dir, "resource.yml")}})
diags := bundle.Apply(context.Background(), b, mutator.TranslatePaths())
require.NoError(t, diags.Error())
// updated to source path
assert.Equal(
t,
filepath.Join(dir, "my_job_notebook"),
b.Config.Resources.Jobs["job"].Tasks[0].NotebookTask.NotebookPath,
)
assert.Equal(
t,
filepath.Join(dir, "requirements.txt"),
b.Config.Resources.Jobs["job"].Tasks[2].Libraries[0].Requirements,
)
assert.Equal(
t,
filepath.Join(dir, "my_python_file.py"),
b.Config.Resources.Jobs["job"].Tasks[3].SparkPythonTask.PythonFile,
)
assert.Equal(
t,
filepath.Join(dir, "my_pipeline_notebook"),
b.Config.Resources.Pipelines["pipeline"].Libraries[0].Notebook.Path,
)
assert.Equal(
t,
filepath.Join(dir, "my_python_file.py"),
b.Config.Resources.Pipelines["pipeline"].Libraries[2].File.Path,
)
// left as is
assert.Equal(
t,
filepath.Join("dist", "task.whl"),
b.Config.Resources.Jobs["job"].Tasks[0].Libraries[0].Whl,
)
assert.Equal(
t,
"/Users/jane.doe@databricks.com/absolute_remote.py",
b.Config.Resources.Jobs["job"].Tasks[1].NotebookTask.NotebookPath,
)
assert.Equal(
t,
filepath.Join("dist", "task.jar"),
b.Config.Resources.Jobs["job"].Tasks[4].Libraries[0].Jar,
)
assert.Equal(
t,
"dbfs:/bundle/dist/task_remote.jar",
b.Config.Resources.Jobs["job"].Tasks[5].Libraries[0].Jar,
)
assert.Equal(
t,
"/Users/jane.doe@databricks.com/absolute_remote.py",
b.Config.Resources.Pipelines["pipeline"].Libraries[1].Notebook.Path,
)
}

View File

@ -17,6 +17,11 @@ type Presets struct {
// JobsMaxConcurrentRuns is the default value for the max concurrent runs of jobs. // JobsMaxConcurrentRuns is the default value for the max concurrent runs of jobs.
JobsMaxConcurrentRuns int `json:"jobs_max_concurrent_runs,omitempty"` JobsMaxConcurrentRuns int `json:"jobs_max_concurrent_runs,omitempty"`
// SourceLinkedDeployment indicates whether source-linked deployment is enabled. Works only in Databricks Workspace
// When set to true, resources created during deployment will point to source files in the workspace instead of their workspace copies.
// File synchronization to ${workspace.file_path} is skipped.
SourceLinkedDeployment *bool `json:"source_linked_deployment,omitempty"`
// Tags to add to all resources. // Tags to add to all resources.
Tags map[string]string `json:"tags,omitempty"` Tags map[string]string `json:"tags,omitempty"`
} }

View File

@ -21,6 +21,12 @@ func (v *filesToSync) Name() string {
} }
func (v *filesToSync) Apply(ctx context.Context, rb bundle.ReadOnlyBundle) diag.Diagnostics { func (v *filesToSync) Apply(ctx context.Context, rb bundle.ReadOnlyBundle) diag.Diagnostics {
// The user may be intentional about not synchronizing any files.
// In this case, we should not show any warnings.
if len(rb.Config().Sync.Paths) == 0 {
return nil
}
sync, err := files.GetSync(ctx, rb) sync, err := files.GetSync(ctx, rb)
if err != nil { if err != nil {
return diag.FromErr(err) return diag.FromErr(err)
@ -31,6 +37,7 @@ func (v *filesToSync) Apply(ctx context.Context, rb bundle.ReadOnlyBundle) diag.
return diag.FromErr(err) return diag.FromErr(err)
} }
// If there are files to sync, we don't need to show any warnings.
if len(fl) != 0 { if len(fl) != 0 {
return nil return nil
} }

View File

@ -0,0 +1,105 @@
package validate
import (
"context"
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/vfs"
sdkconfig "github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/workspace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestFilesToSync_NoPaths(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Sync: config.Sync{
Paths: []string{},
},
},
}
ctx := context.Background()
rb := bundle.ReadOnly(b)
diags := bundle.ApplyReadOnly(ctx, rb, FilesToSync())
assert.Empty(t, diags)
}
func setupBundleForFilesToSyncTest(t *testing.T) *bundle.Bundle {
dir := t.TempDir()
testutil.Touch(t, dir, "file1")
testutil.Touch(t, dir, "file2")
b := &bundle.Bundle{
BundleRootPath: dir,
BundleRoot: vfs.MustNew(dir),
SyncRootPath: dir,
SyncRoot: vfs.MustNew(dir),
Config: config.Root{
Bundle: config.Bundle{
Target: "default",
},
Workspace: config.Workspace{
FilePath: "/this/doesnt/matter",
CurrentUser: &config.User{
User: &iam.User{},
},
},
Sync: config.Sync{
// Paths are relative to [SyncRootPath].
Paths: []string{"."},
},
},
}
m := mocks.NewMockWorkspaceClient(t)
m.WorkspaceClient.Config = &sdkconfig.Config{
Host: "https://foo.com",
}
// The initialization logic in [sync.New] performs a check on the destination path.
// Removing this check at initialization time is tbd...
m.GetMockWorkspaceAPI().EXPECT().GetStatusByPath(mock.Anything, "/this/doesnt/matter").Return(&workspace.ObjectInfo{
ObjectType: workspace.ObjectTypeDirectory,
}, nil)
b.SetWorkpaceClient(m.WorkspaceClient)
return b
}
func TestFilesToSync_EverythingIgnored(t *testing.T) {
b := setupBundleForFilesToSyncTest(t)
// Ignore all files.
testutil.WriteFile(t, "*\n.*\n", b.BundleRootPath, ".gitignore")
ctx := context.Background()
rb := bundle.ReadOnly(b)
diags := bundle.ApplyReadOnly(ctx, rb, FilesToSync())
require.Equal(t, 1, len(diags))
assert.Equal(t, diag.Warning, diags[0].Severity)
assert.Equal(t, "There are no files to sync, please check your .gitignore", diags[0].Summary)
}
func TestFilesToSync_EverythingExcluded(t *testing.T) {
b := setupBundleForFilesToSyncTest(t)
// Exclude all files.
b.Config.Sync.Exclude = []string{"*"}
ctx := context.Background()
rb := bundle.ReadOnly(b)
diags := bundle.ApplyReadOnly(ctx, rb, FilesToSync())
require.Equal(t, 1, len(diags))
assert.Equal(t, diag.Warning, diags[0].Severity)
assert.Equal(t, "There are no files to sync, please check your .gitignore and sync.exclude configuration", diags[0].Summary)
}

View File

@ -0,0 +1,137 @@
package validate
import (
"context"
"strings"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/cli/libs/log"
)
// Validates that any single node clusters defined in the bundle are correctly configured.
func SingleNodeCluster() bundle.ReadOnlyMutator {
return &singleNodeCluster{}
}
type singleNodeCluster struct{}
func (m *singleNodeCluster) Name() string {
return "validate:SingleNodeCluster"
}
const singleNodeWarningDetail = `num_workers should be 0 only for single-node clusters. To create a
valid single node cluster please ensure that the following properties
are correctly set in the cluster specification:
spark_conf:
spark.databricks.cluster.profile: singleNode
spark.master: local[*]
custom_tags:
ResourceClass: SingleNode
`
const singleNodeWarningSummary = `Single node cluster is not correctly configured`
func showSingleNodeClusterWarning(ctx context.Context, v dyn.Value) bool {
// Check if the user has explicitly set the num_workers to 0. Skip the warning
// if that's not the case.
numWorkers, ok := v.Get("num_workers").AsInt()
if !ok || numWorkers > 0 {
return false
}
// Convenient type that contains the common fields from compute.ClusterSpec and
// pipelines.PipelineCluster that we are interested in.
type ClusterConf struct {
SparkConf map[string]string `json:"spark_conf"`
CustomTags map[string]string `json:"custom_tags"`
PolicyId string `json:"policy_id"`
}
conf := &ClusterConf{}
err := convert.ToTyped(conf, v)
if err != nil {
return false
}
// If the policy id is set, we don't want to show the warning. This is because
// the user might have configured `spark_conf` and `custom_tags` correctly
// in their cluster policy.
if conf.PolicyId != "" {
return false
}
profile, ok := conf.SparkConf["spark.databricks.cluster.profile"]
if !ok {
log.Debugf(ctx, "spark_conf spark.databricks.cluster.profile not found in single-node cluster spec")
return true
}
if profile != "singleNode" {
log.Debugf(ctx, "spark_conf spark.databricks.cluster.profile is not singleNode in single-node cluster spec: %s", profile)
return true
}
master, ok := conf.SparkConf["spark.master"]
if !ok {
log.Debugf(ctx, "spark_conf spark.master not found in single-node cluster spec")
return true
}
if !strings.HasPrefix(master, "local") {
log.Debugf(ctx, "spark_conf spark.master does not start with local in single-node cluster spec: %s", master)
return true
}
resourceClass, ok := conf.CustomTags["ResourceClass"]
if !ok {
log.Debugf(ctx, "custom_tag ResourceClass not found in single-node cluster spec")
return true
}
if resourceClass != "SingleNode" {
log.Debugf(ctx, "custom_tag ResourceClass is not SingleNode in single-node cluster spec: %s", resourceClass)
return true
}
return false
}
func (m *singleNodeCluster) Apply(ctx context.Context, rb bundle.ReadOnlyBundle) diag.Diagnostics {
diags := diag.Diagnostics{}
patterns := []dyn.Pattern{
// Interactive clusters
dyn.NewPattern(dyn.Key("resources"), dyn.Key("clusters"), dyn.AnyKey()),
// Job clusters
dyn.NewPattern(dyn.Key("resources"), dyn.Key("jobs"), dyn.AnyKey(), dyn.Key("job_clusters"), dyn.AnyIndex(), dyn.Key("new_cluster")),
// Job task clusters
dyn.NewPattern(dyn.Key("resources"), dyn.Key("jobs"), dyn.AnyKey(), dyn.Key("tasks"), dyn.AnyIndex(), dyn.Key("new_cluster")),
// Job for each task clusters
dyn.NewPattern(dyn.Key("resources"), dyn.Key("jobs"), dyn.AnyKey(), dyn.Key("tasks"), dyn.AnyIndex(), dyn.Key("for_each_task"), dyn.Key("task"), dyn.Key("new_cluster")),
// Pipeline clusters
dyn.NewPattern(dyn.Key("resources"), dyn.Key("pipelines"), dyn.AnyKey(), dyn.Key("clusters"), dyn.AnyIndex()),
}
for _, p := range patterns {
_, err := dyn.MapByPattern(rb.Config().Value(), p, func(p dyn.Path, v dyn.Value) (dyn.Value, error) {
warning := diag.Diagnostic{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: v.Locations(),
Paths: []dyn.Path{p},
}
if showSingleNodeClusterWarning(ctx, v) {
diags = append(diags, warning)
}
return v, nil
})
if err != nil {
log.Debugf(ctx, "Error while applying single node cluster validation: %s", err)
}
}
return diags
}

View File

@ -0,0 +1,566 @@
package validate
import (
"context"
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/bundle/internal/bundletest"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/assert"
)
func failCases() []struct {
name string
sparkConf map[string]string
customTags map[string]string
} {
return []struct {
name string
sparkConf map[string]string
customTags map[string]string
}{
{
name: "no tags or conf",
},
{
name: "no tags",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
"spark.master": "local[*]",
},
},
{
name: "no conf",
customTags: map[string]string{"ResourceClass": "SingleNode"},
},
{
name: "invalid spark cluster profile",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "invalid",
"spark.master": "local[*]",
},
customTags: map[string]string{"ResourceClass": "SingleNode"},
},
{
name: "invalid spark.master",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
"spark.master": "invalid",
},
customTags: map[string]string{"ResourceClass": "SingleNode"},
},
{
name: "invalid tags",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
"spark.master": "local[*]",
},
customTags: map[string]string{"ResourceClass": "invalid"},
},
{
name: "missing ResourceClass tag",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
"spark.master": "local[*]",
},
customTags: map[string]string{"what": "ever"},
},
{
name: "missing spark.master",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
},
customTags: map[string]string{"ResourceClass": "SingleNode"},
},
{
name: "missing spark.databricks.cluster.profile",
sparkConf: map[string]string{
"spark.master": "local[*]",
},
customTags: map[string]string{"ResourceClass": "SingleNode"},
},
}
}
func TestValidateSingleNodeClusterFailForInteractiveClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range failCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Clusters: map[string]*resources.Cluster{
"foo": {
ClusterSpec: &compute.ClusterSpec{
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.clusters.foo", []dyn.Location{{File: "a.yml", Line: 1, Column: 1}})
// We can't set num_workers to 0 explicitly in the typed configuration.
// Do it on the dyn.Value directly.
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.clusters.foo.num_workers", dyn.V(0))
})
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Equal(t, diag.Diagnostics{
{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: []dyn.Location{{File: "a.yml", Line: 1, Column: 1}},
Paths: []dyn.Path{dyn.NewPath(dyn.Key("resources"), dyn.Key("clusters"), dyn.Key("foo"))},
},
}, diags)
})
}
}
func TestValidateSingleNodeClusterFailForJobClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range failCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
JobClusters: []jobs.JobCluster{
{
NewCluster: compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
},
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.jobs.foo.job_clusters[0].new_cluster", []dyn.Location{{File: "b.yml", Line: 1, Column: 1}})
// We can't set num_workers to 0 explicitly in the typed configuration.
// Do it on the dyn.Value directly.
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.job_clusters[0].new_cluster.num_workers", dyn.V(0))
})
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Equal(t, diag.Diagnostics{
{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: []dyn.Location{{File: "b.yml", Line: 1, Column: 1}},
Paths: []dyn.Path{dyn.MustPathFromString("resources.jobs.foo.job_clusters[0].new_cluster")},
},
}, diags)
})
}
}
func TestValidateSingleNodeClusterFailForJobTaskClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range failCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
NewCluster: &compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
},
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.jobs.foo.tasks[0].new_cluster", []dyn.Location{{File: "c.yml", Line: 1, Column: 1}})
// We can't set num_workers to 0 explicitly in the typed configuration.
// Do it on the dyn.Value directly.
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.tasks[0].new_cluster.num_workers", dyn.V(0))
})
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Equal(t, diag.Diagnostics{
{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: []dyn.Location{{File: "c.yml", Line: 1, Column: 1}},
Paths: []dyn.Path{dyn.MustPathFromString("resources.jobs.foo.tasks[0].new_cluster")},
},
}, diags)
})
}
}
func TestValidateSingleNodeClusterFailForPipelineClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range failCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"foo": {
PipelineSpec: &pipelines.PipelineSpec{
Clusters: []pipelines.PipelineCluster{
{
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.pipelines.foo.clusters[0]", []dyn.Location{{File: "d.yml", Line: 1, Column: 1}})
// We can't set num_workers to 0 explicitly in the typed configuration.
// Do it on the dyn.Value directly.
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.pipelines.foo.clusters[0].num_workers", dyn.V(0))
})
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Equal(t, diag.Diagnostics{
{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: []dyn.Location{{File: "d.yml", Line: 1, Column: 1}},
Paths: []dyn.Path{dyn.MustPathFromString("resources.pipelines.foo.clusters[0]")},
},
}, diags)
})
}
}
func TestValidateSingleNodeClusterFailForJobForEachTaskCluster(t *testing.T) {
ctx := context.Background()
for _, tc := range failCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
ForEachTask: &jobs.ForEachTask{
Task: jobs.Task{
NewCluster: &compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
},
},
},
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.jobs.foo.tasks[0].for_each_task.task.new_cluster", []dyn.Location{{File: "e.yml", Line: 1, Column: 1}})
// We can't set num_workers to 0 explicitly in the typed configuration.
// Do it on the dyn.Value directly.
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.tasks[0].for_each_task.task.new_cluster.num_workers", dyn.V(0))
})
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Equal(t, diag.Diagnostics{
{
Severity: diag.Warning,
Summary: singleNodeWarningSummary,
Detail: singleNodeWarningDetail,
Locations: []dyn.Location{{File: "e.yml", Line: 1, Column: 1}},
Paths: []dyn.Path{dyn.MustPathFromString("resources.jobs.foo.tasks[0].for_each_task.task.new_cluster")},
},
}, diags)
})
}
}
func passCases() []struct {
name string
numWorkers *int
sparkConf map[string]string
customTags map[string]string
policyId string
} {
zero := 0
one := 1
return []struct {
name string
numWorkers *int
sparkConf map[string]string
customTags map[string]string
policyId string
}{
{
name: "single node cluster",
sparkConf: map[string]string{
"spark.databricks.cluster.profile": "singleNode",
"spark.master": "local[*]",
},
customTags: map[string]string{
"ResourceClass": "SingleNode",
},
numWorkers: &zero,
},
{
name: "num workers is not zero",
numWorkers: &one,
},
{
name: "num workers is not set",
},
{
name: "policy id is not empty",
policyId: "policy-abc",
numWorkers: &zero,
},
}
}
func TestValidateSingleNodeClusterPassInteractiveClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range passCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Clusters: map[string]*resources.Cluster{
"foo": {
ClusterSpec: &compute.ClusterSpec{
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
PolicyId: tc.policyId,
},
},
},
},
},
}
if tc.numWorkers != nil {
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.clusters.foo.num_workers", dyn.V(*tc.numWorkers))
})
}
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Empty(t, diags)
})
}
}
func TestValidateSingleNodeClusterPassJobClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range passCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
JobClusters: []jobs.JobCluster{
{
NewCluster: compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
PolicyId: tc.policyId,
},
},
},
},
},
},
},
},
}
if tc.numWorkers != nil {
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.job_clusters[0].new_cluster.num_workers", dyn.V(*tc.numWorkers))
})
}
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Empty(t, diags)
})
}
}
func TestValidateSingleNodeClusterPassJobTaskClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range passCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
NewCluster: &compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
PolicyId: tc.policyId,
},
},
},
},
},
},
},
},
}
if tc.numWorkers != nil {
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.tasks[0].new_cluster.num_workers", dyn.V(*tc.numWorkers))
})
}
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Empty(t, diags)
})
}
}
func TestValidateSingleNodeClusterPassPipelineClusters(t *testing.T) {
ctx := context.Background()
for _, tc := range passCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Pipelines: map[string]*resources.Pipeline{
"foo": {
PipelineSpec: &pipelines.PipelineSpec{
Clusters: []pipelines.PipelineCluster{
{
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
PolicyId: tc.policyId,
},
},
},
},
},
},
},
}
if tc.numWorkers != nil {
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.pipelines.foo.clusters[0].num_workers", dyn.V(*tc.numWorkers))
})
}
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Empty(t, diags)
})
}
}
func TestValidateSingleNodeClusterPassJobForEachTaskCluster(t *testing.T) {
ctx := context.Background()
for _, tc := range passCases() {
t.Run(tc.name, func(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"foo": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
ForEachTask: &jobs.ForEachTask{
Task: jobs.Task{
NewCluster: &compute.ClusterSpec{
ClusterName: "my_cluster",
SparkConf: tc.sparkConf,
CustomTags: tc.customTags,
PolicyId: tc.policyId,
},
},
},
},
},
},
},
},
},
},
}
if tc.numWorkers != nil {
bundletest.Mutate(t, b, func(v dyn.Value) (dyn.Value, error) {
return dyn.Set(v, "resources.jobs.foo.tasks[0].for_each_task.task.new_cluster.num_workers", dyn.V(*tc.numWorkers))
})
}
diags := bundle.ApplyReadOnly(ctx, bundle.ReadOnly(b), SingleNodeCluster())
assert.Empty(t, diags)
})
}
}

View File

@ -36,6 +36,7 @@ func (v *validate) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics
ValidateSyncPatterns(), ValidateSyncPatterns(),
JobTaskClusterSpec(), JobTaskClusterSpec(),
ValidateFolderPermissions(), ValidateFolderPermissions(),
SingleNodeCluster(),
)) ))
} }

View File

@ -1,11 +1,8 @@
// Code generated from OpenAPI specs by Databricks SDK Generator. DO NOT EDIT.
package variable package variable
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go"
) )
@ -25,6 +22,8 @@ type Lookup struct {
Metastore string `json:"metastore,omitempty"` Metastore string `json:"metastore,omitempty"`
NotificationDestination string `json:"notification_destination,omitempty"`
Pipeline string `json:"pipeline,omitempty"` Pipeline string `json:"pipeline,omitempty"`
Query string `json:"query,omitempty"` Query string `json:"query,omitempty"`
@ -34,323 +33,78 @@ type Lookup struct {
Warehouse string `json:"warehouse,omitempty"` Warehouse string `json:"warehouse,omitempty"`
} }
func LookupFromMap(m map[string]any) *Lookup { type resolver interface {
l := &Lookup{} // Resolve resolves the underlying entity's ID.
if v, ok := m["alert"]; ok { Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error)
l.Alert = v.(string)
// String returns a human-readable representation of the resolver.
String() string
}
func (l *Lookup) constructResolver() (resolver, error) {
var resolvers []resolver
if l.Alert != "" {
resolvers = append(resolvers, resolveAlert{name: l.Alert})
} }
if v, ok := m["cluster_policy"]; ok { if l.ClusterPolicy != "" {
l.ClusterPolicy = v.(string) resolvers = append(resolvers, resolveClusterPolicy{name: l.ClusterPolicy})
} }
if v, ok := m["cluster"]; ok { if l.Cluster != "" {
l.Cluster = v.(string) resolvers = append(resolvers, resolveCluster{name: l.Cluster})
} }
if v, ok := m["dashboard"]; ok { if l.Dashboard != "" {
l.Dashboard = v.(string) resolvers = append(resolvers, resolveDashboard{name: l.Dashboard})
} }
if v, ok := m["instance_pool"]; ok { if l.InstancePool != "" {
l.InstancePool = v.(string) resolvers = append(resolvers, resolveInstancePool{name: l.InstancePool})
} }
if v, ok := m["job"]; ok { if l.Job != "" {
l.Job = v.(string) resolvers = append(resolvers, resolveJob{name: l.Job})
} }
if v, ok := m["metastore"]; ok { if l.Metastore != "" {
l.Metastore = v.(string) resolvers = append(resolvers, resolveMetastore{name: l.Metastore})
} }
if v, ok := m["pipeline"]; ok { if l.NotificationDestination != "" {
l.Pipeline = v.(string) resolvers = append(resolvers, resolveNotificationDestination{name: l.NotificationDestination})
} }
if v, ok := m["query"]; ok { if l.Pipeline != "" {
l.Query = v.(string) resolvers = append(resolvers, resolvePipeline{name: l.Pipeline})
} }
if v, ok := m["service_principal"]; ok { if l.Query != "" {
l.ServicePrincipal = v.(string) resolvers = append(resolvers, resolveQuery{name: l.Query})
} }
if v, ok := m["warehouse"]; ok { if l.ServicePrincipal != "" {
l.Warehouse = v.(string) resolvers = append(resolvers, resolveServicePrincipal{name: l.ServicePrincipal})
}
if l.Warehouse != "" {
resolvers = append(resolvers, resolveWarehouse{name: l.Warehouse})
} }
return l switch len(resolvers) {
case 0:
return nil, fmt.Errorf("no valid lookup fields provided")
case 1:
return resolvers[0], nil
default:
return nil, fmt.Errorf("exactly one lookup field must be provided")
}
} }
func (l *Lookup) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { func (l *Lookup) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
if err := l.validate(); err != nil { r, err := l.constructResolver()
if err != nil {
return "", err return "", err
} }
r := allResolvers() return r.Resolve(ctx, w)
if l.Alert != "" {
return r.Alert(ctx, w, l.Alert)
}
if l.ClusterPolicy != "" {
return r.ClusterPolicy(ctx, w, l.ClusterPolicy)
}
if l.Cluster != "" {
return r.Cluster(ctx, w, l.Cluster)
}
if l.Dashboard != "" {
return r.Dashboard(ctx, w, l.Dashboard)
}
if l.InstancePool != "" {
return r.InstancePool(ctx, w, l.InstancePool)
}
if l.Job != "" {
return r.Job(ctx, w, l.Job)
}
if l.Metastore != "" {
return r.Metastore(ctx, w, l.Metastore)
}
if l.Pipeline != "" {
return r.Pipeline(ctx, w, l.Pipeline)
}
if l.Query != "" {
return r.Query(ctx, w, l.Query)
}
if l.ServicePrincipal != "" {
return r.ServicePrincipal(ctx, w, l.ServicePrincipal)
}
if l.Warehouse != "" {
return r.Warehouse(ctx, w, l.Warehouse)
}
return "", fmt.Errorf("no valid lookup fields provided")
} }
func (l *Lookup) String() string { func (l *Lookup) String() string {
if l.Alert != "" { r, _ := l.constructResolver()
return fmt.Sprintf("alert: %s", l.Alert) if r == nil {
} return ""
if l.ClusterPolicy != "" {
return fmt.Sprintf("cluster-policy: %s", l.ClusterPolicy)
}
if l.Cluster != "" {
return fmt.Sprintf("cluster: %s", l.Cluster)
}
if l.Dashboard != "" {
return fmt.Sprintf("dashboard: %s", l.Dashboard)
}
if l.InstancePool != "" {
return fmt.Sprintf("instance-pool: %s", l.InstancePool)
}
if l.Job != "" {
return fmt.Sprintf("job: %s", l.Job)
}
if l.Metastore != "" {
return fmt.Sprintf("metastore: %s", l.Metastore)
}
if l.Pipeline != "" {
return fmt.Sprintf("pipeline: %s", l.Pipeline)
}
if l.Query != "" {
return fmt.Sprintf("query: %s", l.Query)
}
if l.ServicePrincipal != "" {
return fmt.Sprintf("service-principal: %s", l.ServicePrincipal)
}
if l.Warehouse != "" {
return fmt.Sprintf("warehouse: %s", l.Warehouse)
} }
return "" return r.String()
}
func (l *Lookup) validate() error {
// Validate that only one field is set
count := 0
if l.Alert != "" {
count++
}
if l.ClusterPolicy != "" {
count++
}
if l.Cluster != "" {
count++
}
if l.Dashboard != "" {
count++
}
if l.InstancePool != "" {
count++
}
if l.Job != "" {
count++
}
if l.Metastore != "" {
count++
}
if l.Pipeline != "" {
count++
}
if l.Query != "" {
count++
}
if l.ServicePrincipal != "" {
count++
}
if l.Warehouse != "" {
count++
}
if count != 1 {
return fmt.Errorf("exactly one lookup field must be provided")
}
if strings.Contains(l.String(), "${var") {
return fmt.Errorf("lookup fields cannot contain variable references")
}
return nil
}
type resolverFunc func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error)
type resolvers struct {
Alert resolverFunc
ClusterPolicy resolverFunc
Cluster resolverFunc
Dashboard resolverFunc
InstancePool resolverFunc
Job resolverFunc
Metastore resolverFunc
Pipeline resolverFunc
Query resolverFunc
ServicePrincipal resolverFunc
Warehouse resolverFunc
}
func allResolvers() *resolvers {
r := &resolvers{}
r.Alert = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Alert"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Alerts.GetByDisplayName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
r.ClusterPolicy = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["ClusterPolicy"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.ClusterPolicies.GetByName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.PolicyId), nil
}
r.Cluster = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Cluster"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Clusters.GetByClusterName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.ClusterId), nil
}
r.Dashboard = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Dashboard"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Dashboards.GetByName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
r.InstancePool = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["InstancePool"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.InstancePools.GetByInstancePoolName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.InstancePoolId), nil
}
r.Job = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Job"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Jobs.GetBySettingsName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.JobId), nil
}
r.Metastore = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Metastore"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Metastores.GetByName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.MetastoreId), nil
}
r.Pipeline = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Pipeline"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Pipelines.GetByName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.PipelineId), nil
}
r.Query = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Query"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Queries.GetByDisplayName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
r.ServicePrincipal = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["ServicePrincipal"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.ServicePrincipals.GetByDisplayName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.ApplicationId), nil
}
r.Warehouse = func(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) {
fn, ok := lookupOverrides["Warehouse"]
if ok {
return fn(ctx, w, name)
}
entity, err := w.Warehouses.GetByName(ctx, name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
return r
} }

View File

@ -0,0 +1,60 @@
package variable
import (
"context"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLookup_Coverage(t *testing.T) {
var lookup Lookup
val := reflect.ValueOf(lookup)
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
if field.Kind() != reflect.String {
t.Fatalf("Field %s is not a string", typ.Field(i).Name)
}
fieldType := typ.Field(i)
t.Run(fieldType.Name, func(t *testing.T) {
// Use a fresh instance of the struct in each test
var lookup Lookup
// Set the field to a non-empty string
reflect.ValueOf(&lookup).Elem().Field(i).SetString("value")
// Test the [String] function
assert.NotEmpty(t, lookup.String())
})
}
}
func TestLookup_Empty(t *testing.T) {
var lookup Lookup
// Resolve returns an error when no fields are provided
_, err := lookup.Resolve(context.Background(), nil)
assert.ErrorContains(t, err, "no valid lookup fields provided")
// No string representation for an invalid lookup
assert.Empty(t, lookup.String())
}
func TestLookup_Multiple(t *testing.T) {
lookup := Lookup{
Alert: "alert",
Query: "query",
}
// Resolve returns an error when multiple fields are provided
_, err := lookup.Resolve(context.Background(), nil)
assert.ErrorContains(t, err, "exactly one lookup field must be provided")
// No string representation for an invalid lookup
assert.Empty(t, lookup.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveAlert struct {
name string
}
func (l resolveAlert) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Alerts.GetByDisplayName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
func (l resolveAlert) String() string {
return fmt.Sprintf("alert: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveAlert_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockAlertsAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "alert").
Return(&sql.ListAlertsResponseAlert{
Id: "1234",
}, nil)
ctx := context.Background()
l := resolveAlert{name: "alert"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "1234", result)
}
func TestResolveAlert_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockAlertsAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "alert").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveAlert{name: "alert"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveAlert_String(t *testing.T) {
l := resolveAlert{name: "name"}
assert.Equal(t, "alert: name", l.String())
}

View File

@ -8,13 +8,13 @@ import (
"github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/compute"
) )
var lookupOverrides = map[string]resolverFunc{ type resolveCluster struct {
"Cluster": resolveCluster, name string
} }
// We added a custom resolver for the cluster to add filtering for the cluster source when we list all clusters. // We added a custom resolver for the cluster to add filtering for the cluster source when we list all clusters.
// Without the filtering listing could take a very long time (5-10 mins) which leads to lookup timeouts. // Without the filtering listing could take a very long time (5-10 mins) which leads to lookup timeouts.
func resolveCluster(ctx context.Context, w *databricks.WorkspaceClient, name string) (string, error) { func (l resolveCluster) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
result, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{ result, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
FilterBy: &compute.ListClustersFilterBy{ FilterBy: &compute.ListClustersFilterBy{
ClusterSources: []compute.ClusterSource{compute.ClusterSourceApi, compute.ClusterSourceUi}, ClusterSources: []compute.ClusterSource{compute.ClusterSourceApi, compute.ClusterSourceUi},
@ -30,6 +30,8 @@ func resolveCluster(ctx context.Context, w *databricks.WorkspaceClient, name str
key := v.ClusterName key := v.ClusterName
tmp[key] = append(tmp[key], v) tmp[key] = append(tmp[key], v)
} }
name := l.name
alternatives, ok := tmp[name] alternatives, ok := tmp[name]
if !ok || len(alternatives) == 0 { if !ok || len(alternatives) == 0 {
return "", fmt.Errorf("cluster named '%s' does not exist", name) return "", fmt.Errorf("cluster named '%s' does not exist", name)
@ -39,3 +41,7 @@ func resolveCluster(ctx context.Context, w *databricks.WorkspaceClient, name str
} }
return alternatives[0].ClusterId, nil return alternatives[0].ClusterId, nil
} }
func (l resolveCluster) String() string {
return fmt.Sprintf("cluster: %s", l.name)
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveClusterPolicy struct {
name string
}
func (l resolveClusterPolicy) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.ClusterPolicies.GetByName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.PolicyId), nil
}
func (l resolveClusterPolicy) String() string {
return fmt.Sprintf("cluster-policy: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveClusterPolicy_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockClusterPoliciesAPI()
api.EXPECT().
GetByName(mock.Anything, "policy").
Return(&compute.Policy{
PolicyId: "1234",
}, nil)
ctx := context.Background()
l := resolveClusterPolicy{name: "policy"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "1234", result)
}
func TestResolveClusterPolicy_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockClusterPoliciesAPI()
api.EXPECT().
GetByName(mock.Anything, "policy").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveClusterPolicy{name: "policy"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveClusterPolicy_String(t *testing.T) {
l := resolveClusterPolicy{name: "name"}
assert.Equal(t, "cluster-policy: name", l.String())
}

View File

@ -0,0 +1,50 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveCluster_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockClustersAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return([]compute.ClusterDetails{
{ClusterId: "1234", ClusterName: "cluster1"},
{ClusterId: "2345", ClusterName: "cluster2"},
}, nil)
ctx := context.Background()
l := resolveCluster{name: "cluster2"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "2345", result)
}
func TestResolveCluster_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockClustersAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return([]compute.ClusterDetails{}, nil)
ctx := context.Background()
l := resolveCluster{name: "cluster"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.Error(t, err)
assert.Contains(t, err.Error(), "cluster named 'cluster' does not exist")
}
func TestResolveCluster_String(t *testing.T) {
l := resolveCluster{name: "name"}
assert.Equal(t, "cluster: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveDashboard struct {
name string
}
func (l resolveDashboard) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Dashboards.GetByName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
func (l resolveDashboard) String() string {
return fmt.Sprintf("dashboard: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveDashboard_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockDashboardsAPI()
api.EXPECT().
GetByName(mock.Anything, "dashboard").
Return(&sql.Dashboard{
Id: "1234",
}, nil)
ctx := context.Background()
l := resolveDashboard{name: "dashboard"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "1234", result)
}
func TestResolveDashboard_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockDashboardsAPI()
api.EXPECT().
GetByName(mock.Anything, "dashboard").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveDashboard{name: "dashboard"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveDashboard_String(t *testing.T) {
l := resolveDashboard{name: "name"}
assert.Equal(t, "dashboard: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveInstancePool struct {
name string
}
func (l resolveInstancePool) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.InstancePools.GetByInstancePoolName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.InstancePoolId), nil
}
func (l resolveInstancePool) String() string {
return fmt.Sprintf("instance-pool: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveInstancePool_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockInstancePoolsAPI()
api.EXPECT().
GetByInstancePoolName(mock.Anything, "instance_pool").
Return(&compute.InstancePoolAndStats{
InstancePoolId: "5678",
}, nil)
ctx := context.Background()
l := resolveInstancePool{name: "instance_pool"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "5678", result)
}
func TestResolveInstancePool_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockInstancePoolsAPI()
api.EXPECT().
GetByInstancePoolName(mock.Anything, "instance_pool").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveInstancePool{name: "instance_pool"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveInstancePool_String(t *testing.T) {
l := resolveInstancePool{name: "name"}
assert.Equal(t, "instance-pool: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveJob struct {
name string
}
func (l resolveJob) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Jobs.GetBySettingsName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.JobId), nil
}
func (l resolveJob) String() string {
return fmt.Sprintf("job: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveJob_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockJobsAPI()
api.EXPECT().
GetBySettingsName(mock.Anything, "job").
Return(&jobs.BaseJob{
JobId: 5678,
}, nil)
ctx := context.Background()
l := resolveJob{name: "job"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "5678", result)
}
func TestResolveJob_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockJobsAPI()
api.EXPECT().
GetBySettingsName(mock.Anything, "job").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveJob{name: "job"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveJob_String(t *testing.T) {
l := resolveJob{name: "name"}
assert.Equal(t, "job: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveMetastore struct {
name string
}
func (l resolveMetastore) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Metastores.GetByName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.MetastoreId), nil
}
func (l resolveMetastore) String() string {
return fmt.Sprintf("metastore: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveMetastore_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockMetastoresAPI()
api.EXPECT().
GetByName(mock.Anything, "metastore").
Return(&catalog.MetastoreInfo{
MetastoreId: "abcd",
}, nil)
ctx := context.Background()
l := resolveMetastore{name: "metastore"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "abcd", result)
}
func TestResolveMetastore_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockMetastoresAPI()
api.EXPECT().
GetByName(mock.Anything, "metastore").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveMetastore{name: "metastore"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveMetastore_String(t *testing.T) {
l := resolveMetastore{name: "name"}
assert.Equal(t, "metastore: name", l.String())
}

View File

@ -0,0 +1,46 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/settings"
)
type resolveNotificationDestination struct {
name string
}
func (l resolveNotificationDestination) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
result, err := w.NotificationDestinations.ListAll(ctx, settings.ListNotificationDestinationsRequest{
// The default page size for this API is 20.
// We use a higher value to make fewer API calls.
PageSize: 200,
})
if err != nil {
return "", err
}
// Collect all notification destinations with the given name.
var entities []settings.ListNotificationDestinationsResult
for _, entity := range result {
if entity.DisplayName == l.name {
entities = append(entities, entity)
}
}
// Return the ID of the first matching notification destination.
switch len(entities) {
case 0:
return "", fmt.Errorf("notification destination named %q does not exist", l.name)
case 1:
return entities[0].Id, nil
default:
return "", fmt.Errorf("there are %d instances of clusters named %q", len(entities), l.name)
}
}
func (l resolveNotificationDestination) String() string {
return fmt.Sprintf("notification-destination: %s", l.name)
}

View File

@ -0,0 +1,82 @@
package variable
import (
"context"
"fmt"
"testing"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/settings"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveNotificationDestination_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockNotificationDestinationsAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return([]settings.ListNotificationDestinationsResult{
{Id: "1234", DisplayName: "destination"},
}, nil)
ctx := context.Background()
l := resolveNotificationDestination{name: "destination"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "1234", result)
}
func TestResolveNotificationDestination_ResolveError(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockNotificationDestinationsAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("bad"))
ctx := context.Background()
l := resolveNotificationDestination{name: "destination"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
assert.ErrorContains(t, err, "bad")
}
func TestResolveNotificationDestination_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockNotificationDestinationsAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return([]settings.ListNotificationDestinationsResult{}, nil)
ctx := context.Background()
l := resolveNotificationDestination{name: "destination"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.Error(t, err)
assert.ErrorContains(t, err, `notification destination named "destination" does not exist`)
}
func TestResolveNotificationDestination_ResolveMultiple(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockNotificationDestinationsAPI()
api.EXPECT().
ListAll(mock.Anything, mock.Anything).
Return([]settings.ListNotificationDestinationsResult{
{Id: "1234", DisplayName: "destination"},
{Id: "5678", DisplayName: "destination"},
}, nil)
ctx := context.Background()
l := resolveNotificationDestination{name: "destination"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.Error(t, err)
assert.ErrorContains(t, err, `there are 2 instances of clusters named "destination"`)
}
func TestResolveNotificationDestination_String(t *testing.T) {
l := resolveNotificationDestination{name: "name"}
assert.Equal(t, "notification-destination: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolvePipeline struct {
name string
}
func (l resolvePipeline) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Pipelines.GetByName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.PipelineId), nil
}
func (l resolvePipeline) String() string {
return fmt.Sprintf("pipeline: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/pipelines"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolvePipeline_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockPipelinesAPI()
api.EXPECT().
GetByName(mock.Anything, "pipeline").
Return(&pipelines.PipelineStateInfo{
PipelineId: "abcd",
}, nil)
ctx := context.Background()
l := resolvePipeline{name: "pipeline"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "abcd", result)
}
func TestResolvePipeline_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockPipelinesAPI()
api.EXPECT().
GetByName(mock.Anything, "pipeline").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolvePipeline{name: "pipeline"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolvePipeline_String(t *testing.T) {
l := resolvePipeline{name: "name"}
assert.Equal(t, "pipeline: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveQuery struct {
name string
}
func (l resolveQuery) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Queries.GetByDisplayName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
func (l resolveQuery) String() string {
return fmt.Sprintf("query: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveQuery_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockQueriesAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "query").
Return(&sql.ListQueryObjectsResponseQuery{
Id: "1234",
}, nil)
ctx := context.Background()
l := resolveQuery{name: "query"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "1234", result)
}
func TestResolveQuery_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockQueriesAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "query").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveQuery{name: "query"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveQuery_String(t *testing.T) {
l := resolveQuery{name: "name"}
assert.Equal(t, "query: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveServicePrincipal struct {
name string
}
func (l resolveServicePrincipal) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.ServicePrincipals.GetByDisplayName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.ApplicationId), nil
}
func (l resolveServicePrincipal) String() string {
return fmt.Sprintf("service-principal: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveServicePrincipal_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockServicePrincipalsAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "service-principal").
Return(&iam.ServicePrincipal{
ApplicationId: "5678",
}, nil)
ctx := context.Background()
l := resolveServicePrincipal{name: "service-principal"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "5678", result)
}
func TestResolveServicePrincipal_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockServicePrincipalsAPI()
api.EXPECT().
GetByDisplayName(mock.Anything, "service-principal").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveServicePrincipal{name: "service-principal"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveServicePrincipal_String(t *testing.T) {
l := resolveServicePrincipal{name: "name"}
assert.Equal(t, "service-principal: name", l.String())
}

View File

@ -0,0 +1,24 @@
package variable
import (
"context"
"fmt"
"github.com/databricks/databricks-sdk-go"
)
type resolveWarehouse struct {
name string
}
func (l resolveWarehouse) Resolve(ctx context.Context, w *databricks.WorkspaceClient) (string, error) {
entity, err := w.Warehouses.GetByName(ctx, l.name)
if err != nil {
return "", err
}
return fmt.Sprint(entity.Id), nil
}
func (l resolveWarehouse) String() string {
return fmt.Sprintf("warehouse: %s", l.name)
}

View File

@ -0,0 +1,49 @@
package variable
import (
"context"
"testing"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestResolveWarehouse_ResolveSuccess(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockWarehousesAPI()
api.EXPECT().
GetByName(mock.Anything, "warehouse").
Return(&sql.EndpointInfo{
Id: "abcd",
}, nil)
ctx := context.Background()
l := resolveWarehouse{name: "warehouse"}
result, err := l.Resolve(ctx, m.WorkspaceClient)
require.NoError(t, err)
assert.Equal(t, "abcd", result)
}
func TestResolveWarehouse_ResolveNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
api := m.GetMockWarehousesAPI()
api.EXPECT().
GetByName(mock.Anything, "warehouse").
Return(nil, &apierr.APIError{StatusCode: 404})
ctx := context.Background()
l := resolveWarehouse{name: "warehouse"}
_, err := l.Resolve(ctx, m.WorkspaceClient)
require.ErrorIs(t, err, apierr.ErrNotFound)
}
func TestResolveWarehouse_String(t *testing.T) {
l := resolveWarehouse{name: "name"}
assert.Equal(t, "warehouse: name", l.String())
}

View File

@ -7,6 +7,7 @@ import (
"io/fs" "io/fs"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/permissions" "github.com/databricks/cli/bundle/permissions"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
@ -23,6 +24,11 @@ func (m *upload) Name() string {
} }
func (m *upload) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { func (m *upload) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
if config.IsExplicitlyEnabled(b.Config.Presets.SourceLinkedDeployment) {
cmdio.LogString(ctx, "Source-linked deployment is enabled. Deployed resources reference the source files in your working tree instead of separate copies.")
return nil
}
cmdio.LogString(ctx, fmt.Sprintf("Uploading bundle files to %s...", b.Config.Workspace.FilePath)) cmdio.LogString(ctx, fmt.Sprintf("Uploading bundle files to %s...", b.Config.Workspace.FilePath))
opts, err := GetSyncOptions(ctx, bundle.ReadOnly(b)) opts, err := GetSyncOptions(ctx, bundle.ReadOnly(b))
if err != nil { if err != nil {

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/libraries" "github.com/databricks/cli/bundle/libraries"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/log" "github.com/databricks/cli/libs/log"
@ -22,6 +23,9 @@ func WrapperWarning() bundle.Mutator {
func (m *wrapperWarning) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { func (m *wrapperWarning) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
if isPythonWheelWrapperOn(b) { if isPythonWheelWrapperOn(b) {
if config.IsExplicitlyEnabled(b.Config.Presets.SourceLinkedDeployment) {
return diag.Warningf("Python wheel notebook wrapper is not available when using source-linked deployment mode. You can disable this mode by setting 'presets.source_linked_deployment: false'")
}
return nil return nil
} }

View File

@ -31,6 +31,7 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
cmd.AddCommand(newProfilesCommand()) cmd.AddCommand(newProfilesCommand())
cmd.AddCommand(newTokenCommand(&perisistentAuth)) cmd.AddCommand(newTokenCommand(&perisistentAuth))
cmd.AddCommand(newDescribeCommand()) cmd.AddCommand(newDescribeCommand())
cmd.AddCommand(newLogoutCommand(&perisistentAuth))
return cmd return cmd
} }

110
cmd/auth/logout.go Normal file
View File

@ -0,0 +1,110 @@
package auth
import (
"context"
"errors"
"fmt"
"io/fs"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/auth/cache"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
)
type logoutSession struct {
profile string
file config.File
persistentAuth *auth.PersistentAuth
}
func (l *logoutSession) load(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth) error {
l.profile = profileName
l.persistentAuth = persistentAuth
iniFile, err := profile.DefaultProfiler.Get(ctx)
if errors.Is(err, fs.ErrNotExist) {
return err
} else if err != nil {
return fmt.Errorf("cannot parse config file: %w", err)
}
l.file = *iniFile
if err := l.setHostAndAccountIdFromProfile(); err != nil {
return err
}
return nil
}
func (l *logoutSession) setHostAndAccountIdFromProfile() error {
sectionMap, err := l.getConfigSectionMap()
if err != nil {
return err
}
if sectionMap["host"] == "" {
return fmt.Errorf("no host configured for profile %s", l.profile)
}
l.persistentAuth.Host = sectionMap["host"]
l.persistentAuth.AccountID = sectionMap["account_id"]
return nil
}
func (l *logoutSession) getConfigSectionMap() (map[string]string, error) {
section, err := l.file.GetSection(l.profile)
if err != nil {
return map[string]string{}, fmt.Errorf("profile does not exist in config file: %w", err)
}
return section.KeysHash(), nil
}
// clear token from ~/.databricks/token-cache.json
func (l *logoutSession) clearTokenCache(ctx context.Context) error {
return l.persistentAuth.ClearToken(ctx)
}
func newLogoutCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
cmd := &cobra.Command{
Use: "logout [PROFILE]",
Short: "Logout from specified profile",
Long: "Removes the OAuth token from the token-cache",
}
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
profileNameFromFlag := cmd.Flag("profile").Value.String()
// If both [PROFILE] and --profile are provided, return an error.
if len(args) > 0 && profileNameFromFlag != "" {
return fmt.Errorf("please only provide a profile as an argument or a flag, not both")
}
// Determine the profile name from either args or the flag.
profileName := profileNameFromFlag
if len(args) > 0 {
profileName = args[0]
}
// If the user has not specified a profile name, prompt for one.
if profileName == "" {
var err error
profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
if err != nil {
return err
}
}
defer persistentAuth.Close()
logoutSession := &logoutSession{}
err := logoutSession.load(ctx, profileName, persistentAuth)
if err != nil {
return err
}
err = logoutSession.clearTokenCache(ctx)
if err != nil {
if errors.Is(err, cache.ErrNotConfigured) {
// It is OK to not have OAuth configured
} else {
return err
}
}
cmdio.LogString(ctx, fmt.Sprintf("Profile %s is logged out", profileName))
return nil
}
return cmd
}

62
cmd/auth/logout_test.go Normal file
View File

@ -0,0 +1,62 @@
package auth
import (
"context"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/databricks-sdk-go/config"
)
func TestLogout_setHostAndAccountIdFromProfile(t *testing.T) {
ctx := context.Background()
path := filepath.Join(t.TempDir(), "databrickscfg")
err := databrickscfg.SaveToProfile(ctx, &config.Config{
ConfigFile: path,
Profile: "abc",
Host: "https://foo",
Token: "xyz",
})
require.NoError(t, err)
iniFile, err := config.LoadFile(path)
require.NoError(t, err)
logout := &logoutSession{
profile: "abc",
file: *iniFile,
persistentAuth: &auth.PersistentAuth{},
}
err = logout.setHostAndAccountIdFromProfile()
assert.NoError(t, err)
assert.Equal(t, logout.persistentAuth.Host, "https://foo")
assert.Empty(t, logout.persistentAuth.AccountID)
}
func TestLogout_getConfigSectionMap(t *testing.T) {
ctx := context.Background()
path := filepath.Join(t.TempDir(), "databrickscfg")
err := databrickscfg.SaveToProfile(ctx, &config.Config{
ConfigFile: path,
Profile: "abc",
Host: "https://foo",
Token: "xyz",
})
require.NoError(t, err)
iniFile, err := config.LoadFile(path)
require.NoError(t, err)
logout := &logoutSession{
profile: "abc",
file: *iniFile,
persistentAuth: &auth.PersistentAuth{},
}
configSectionMap, err := logout.getConfigSectionMap()
assert.NoError(t, err)
assert.Equal(t, configSectionMap["host"], "https://foo")
assert.Equal(t, configSectionMap["token"], "xyz")
}

View File

@ -3,8 +3,10 @@ package generate
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -90,7 +92,7 @@ func TestGeneratePipelineCommand(t *testing.T) {
err := cmd.RunE(cmd, []string{}) err := cmd.RunE(cmd, []string{})
require.NoError(t, err) require.NoError(t, err)
data, err := os.ReadFile(filepath.Join(configDir, "test_pipeline.yml")) data, err := os.ReadFile(filepath.Join(configDir, "test_pipeline.pipeline.yml"))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fmt.Sprintf(`resources: require.Equal(t, fmt.Sprintf(`resources:
pipelines: pipelines:
@ -186,7 +188,123 @@ func TestGenerateJobCommand(t *testing.T) {
err := cmd.RunE(cmd, []string{}) err := cmd.RunE(cmd, []string{})
require.NoError(t, err) require.NoError(t, err)
data, err := os.ReadFile(filepath.Join(configDir, "test_job.yml")) data, err := os.ReadFile(filepath.Join(configDir, "test_job.job.yml"))
require.NoError(t, err)
require.Equal(t, fmt.Sprintf(`resources:
jobs:
test_job:
name: test-job
job_clusters:
- new_cluster:
custom_tags:
"Tag1": "24X7-1234"
- new_cluster:
spark_conf:
"spark.databricks.delta.preview.enabled": "true"
tasks:
- task_key: notebook_task
notebook_task:
notebook_path: %s
parameters:
- name: empty
default: ""
`, filepath.Join("..", "src", "notebook.py")), string(data))
data, err = os.ReadFile(filepath.Join(srcDir, "notebook.py"))
require.NoError(t, err)
require.Equal(t, "# Databricks notebook source\nNotebook content", string(data))
}
func touchEmptyFile(t *testing.T, path string) {
err := os.MkdirAll(filepath.Dir(path), 0700)
require.NoError(t, err)
f, err := os.Create(path)
require.NoError(t, err)
f.Close()
}
func TestGenerateJobCommandOldFileRename(t *testing.T) {
cmd := NewGenerateJobCommand()
root := t.TempDir()
b := &bundle.Bundle{
BundleRootPath: root,
}
m := mocks.NewMockWorkspaceClient(t)
b.SetWorkpaceClient(m.WorkspaceClient)
jobsApi := m.GetMockJobsAPI()
jobsApi.EXPECT().Get(mock.Anything, jobs.GetJobRequest{JobId: 1234}).Return(&jobs.Job{
Settings: &jobs.JobSettings{
Name: "test-job",
JobClusters: []jobs.JobCluster{
{NewCluster: compute.ClusterSpec{
CustomTags: map[string]string{
"Tag1": "24X7-1234",
},
}},
{NewCluster: compute.ClusterSpec{
SparkConf: map[string]string{
"spark.databricks.delta.preview.enabled": "true",
},
}},
},
Tasks: []jobs.Task{
{
TaskKey: "notebook_task",
NotebookTask: &jobs.NotebookTask{
NotebookPath: "/test/notebook",
},
},
},
Parameters: []jobs.JobParameterDefinition{
{
Name: "empty",
Default: "",
},
},
},
}, nil)
workspaceApi := m.GetMockWorkspaceAPI()
workspaceApi.EXPECT().GetStatusByPath(mock.Anything, "/test/notebook").Return(&workspace.ObjectInfo{
ObjectType: workspace.ObjectTypeNotebook,
Language: workspace.LanguagePython,
Path: "/test/notebook",
}, nil)
notebookContent := io.NopCloser(bytes.NewBufferString("# Databricks notebook source\nNotebook content"))
workspaceApi.EXPECT().Download(mock.Anything, "/test/notebook", mock.Anything).Return(notebookContent, nil)
cmd.SetContext(bundle.Context(context.Background(), b))
cmd.Flag("existing-job-id").Value.Set("1234")
configDir := filepath.Join(root, "resources")
cmd.Flag("config-dir").Value.Set(configDir)
srcDir := filepath.Join(root, "src")
cmd.Flag("source-dir").Value.Set(srcDir)
var key string
cmd.Flags().StringVar(&key, "key", "test_job", "")
// Create an old generated file first
oldFilename := filepath.Join(configDir, "test_job.yml")
touchEmptyFile(t, oldFilename)
// Having an existing files require --force flag to regenerate them
cmd.Flag("force").Value.Set("true")
err := cmd.RunE(cmd, []string{})
require.NoError(t, err)
// Make sure file do not exists after the run
_, err = os.Stat(oldFilename)
require.True(t, errors.Is(err, fs.ErrNotExist))
data, err := os.ReadFile(filepath.Join(configDir, "test_job.job.yml"))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fmt.Sprintf(`resources: require.Equal(t, fmt.Sprintf(`resources:

View File

@ -1,7 +1,9 @@
package generate package generate
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
@ -83,7 +85,17 @@ func NewGenerateJobCommand() *cobra.Command {
return err return err
} }
filename := filepath.Join(configDir, fmt.Sprintf("%s.yml", jobKey)) oldFilename := filepath.Join(configDir, fmt.Sprintf("%s.yml", jobKey))
filename := filepath.Join(configDir, fmt.Sprintf("%s.job.yml", jobKey))
// User might continuously run generate command to update their bundle jobs with any changes made in Databricks UI.
// Due to changing in the generated file names, we need to first rename existing resource file to the new name.
// Otherwise users can end up with duplicated resources.
err = os.Rename(oldFilename, filename)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to rename file %s. DABs uses the resource type as a sub-extension for generated content, please rename it to %s, err: %w", oldFilename, filename, err)
}
saver := yamlsaver.NewSaverWithStyle(map[string]yaml.Style{ saver := yamlsaver.NewSaverWithStyle(map[string]yaml.Style{
// Including all JobSettings and nested fields which are map[string]string type // Including all JobSettings and nested fields which are map[string]string type
"spark_conf": yaml.DoubleQuotedStyle, "spark_conf": yaml.DoubleQuotedStyle,

View File

@ -1,7 +1,9 @@
package generate package generate
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
@ -83,7 +85,17 @@ func NewGeneratePipelineCommand() *cobra.Command {
return err return err
} }
filename := filepath.Join(configDir, fmt.Sprintf("%s.yml", pipelineKey)) oldFilename := filepath.Join(configDir, fmt.Sprintf("%s.yml", pipelineKey))
filename := filepath.Join(configDir, fmt.Sprintf("%s.pipeline.yml", pipelineKey))
// User might continuously run generate command to update their bundle jobs with any changes made in Databricks UI.
// Due to changing in the generated file names, we need to first rename existing resource file to the new name.
// Otherwise users can end up with duplicated resources.
err = os.Rename(oldFilename, filename)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to rename file %s. DABs uses the resource type as a sub-extension for generated content, please rename it to %s, err: %w", oldFilename, filename, err)
}
saver := yamlsaver.NewSaverWithStyle( saver := yamlsaver.NewSaverWithStyle(
// Including all PipelineSpec and nested fields which are map[string]string type // Including all PipelineSpec and nested fields which are map[string]string type
map[string]yaml.Style{ map[string]yaml.Style{

View File

@ -1,8 +1,10 @@
package bundle package bundle
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
@ -10,6 +12,8 @@ import (
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/dbr"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/git"
"github.com/databricks/cli/libs/template" "github.com/databricks/cli/libs/template"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -109,6 +113,24 @@ func getUrlForNativeTemplate(name string) string {
return "" return ""
} }
func getFsForNativeTemplate(name string) (fs.FS, error) {
builtin, err := template.Builtin()
if err != nil {
return nil, err
}
// If this is a built-in template, the return value will be non-nil.
var templateFS fs.FS
for _, entry := range builtin {
if entry.Name == name {
templateFS = entry.FS
break
}
}
return templateFS, nil
}
func isRepoUrl(url string) bool { func isRepoUrl(url string) bool {
result := false result := false
for _, prefix := range gitUrlPrefixes { for _, prefix := range gitUrlPrefixes {
@ -128,6 +150,26 @@ func repoName(url string) string {
return parts[len(parts)-1] return parts[len(parts)-1]
} }
func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) {
outputDir, err := filepath.Abs(outputDir)
if err != nil {
return nil, err
}
// If the CLI is running on DBR and we're writing to the workspace file system,
// use the extension-aware workspace filesystem filer to instantiate the template.
//
// It is not possible to write notebooks through the workspace filesystem's FUSE mount.
// Therefore this is the only way we can initialize templates that contain notebooks
// when running the CLI on DBR and initializing a template to the workspace.
//
if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) {
return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir)
}
return filer.NewLocalClient(outputDir)
}
func newInitCommand() *cobra.Command { func newInitCommand() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "init [TEMPLATE_PATH]", Use: "init [TEMPLATE_PATH]",
@ -182,6 +224,11 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
templatePath = getNativeTemplateByDescription(description) templatePath = getNativeTemplateByDescription(description)
} }
outputFiler, err := constructOutputFiler(ctx, outputDir)
if err != nil {
return err
}
if templatePath == customTemplate { if templatePath == customTemplate {
cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.")
cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.")
@ -198,9 +245,20 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
if templateDir != "" { if templateDir != "" {
return errors.New("--template-dir can only be used with a Git repository URL") return errors.New("--template-dir can only be used with a Git repository URL")
} }
templateFS, err := getFsForNativeTemplate(templatePath)
if err != nil {
return err
}
// If this is not a built-in template, then it must be a local file system path.
if templateFS == nil {
templateFS = os.DirFS(templatePath)
}
// skip downloading the repo because input arg is not a URL. We assume // skip downloading the repo because input arg is not a URL. We assume
// it's a path on the local file system in that case // it's a path on the local file system in that case
return template.Materialize(ctx, configFile, templatePath, outputDir) return template.Materialize(ctx, configFile, templateFS, outputFiler)
} }
// Create a temporary directory with the name of the repository. The '*' // Create a temporary directory with the name of the repository. The '*'
@ -224,7 +282,8 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf
// Clean up downloaded repository once the template is materialized. // Clean up downloaded repository once the template is materialized.
defer os.RemoveAll(repoDir) defer os.RemoveAll(repoDir)
return template.Materialize(ctx, configFile, filepath.Join(repoDir, templateDir), outputDir) templateFS := os.DirFS(filepath.Join(repoDir, templateDir))
return template.Materialize(ctx, configFile, templateFS, outputFiler)
} }
return cmd return cmd
} }

View File

@ -166,7 +166,7 @@ func TestAccGenerateAndBind(t *testing.T) {
_, err = os.Stat(filepath.Join(bundleRoot, "src", "test.py")) _, err = os.Stat(filepath.Join(bundleRoot, "src", "test.py"))
require.NoError(t, err) require.NoError(t, err)
matches, err := filepath.Glob(filepath.Join(bundleRoot, "resources", "test_job_key.yml")) matches, err := filepath.Glob(filepath.Join(bundleRoot, "resources", "test_job_key.job.yml"))
require.NoError(t, err) require.NoError(t, err)
require.Len(t, matches, 1) require.Len(t, matches, 1)

View File

@ -11,6 +11,11 @@
"node_type_id": { "node_type_id": {
"type": "string", "type": "string",
"description": "Node type id for job cluster" "description": "Node type id for job cluster"
},
"root_path": {
"type": "string",
"description": "Root path to deploy bundle to",
"default": ""
} }
} }
} }

View File

@ -2,7 +2,11 @@ bundle:
name: basic name: basic
workspace: workspace:
{{ if .root_path }}
root_path: "{{.root_path}}/.bundle/{{.unique_id}}"
{{ else }}
root_path: "~/.bundle/{{.unique_id}}" root_path: "~/.bundle/{{.unique_id}}"
{{ end }}
resources: resources:
jobs: jobs:

View File

@ -1,2 +0,0 @@
bundle:
name: abc

View File

@ -1,5 +1,8 @@
bundle: bundle:
name: "bundle-playground" name: recreate-pipeline
workspace:
root_path: "~/.bundle/{{.unique_id}}"
variables: variables:
catalog: catalog:

View File

@ -1,5 +1,8 @@
bundle: bundle:
name: "bundle-playground" name: uc-schema
workspace:
root_path: "~/.bundle/{{.unique_id}}"
resources: resources:
pipelines: pipelines:

View File

@ -0,0 +1,38 @@
package bundle
import (
"fmt"
"testing"
"github.com/databricks/cli/internal"
"github.com/databricks/cli/internal/acc"
"github.com/databricks/cli/libs/env"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestAccDeployBasicToSharedWorkspacePath(t *testing.T) {
ctx, wt := acc.WorkspaceTest(t)
nodeTypeId := internal.GetNodeTypeId(env.Get(ctx, "CLOUD_ENV"))
uniqueId := uuid.New().String()
currentUser, err := wt.W.CurrentUser.Me(ctx)
require.NoError(t, err)
bundleRoot, err := initTestTemplate(t, ctx, "basic", map[string]any{
"unique_id": uniqueId,
"node_type_id": nodeTypeId,
"spark_version": defaultSparkVersion,
"root_path": fmt.Sprintf("/Shared/%s", currentUser.UserName),
})
require.NoError(t, err)
t.Cleanup(func() {
err = destroyBundle(wt.T, ctx, bundleRoot)
require.NoError(wt.T, err)
})
err = deployBundle(wt.T, ctx, bundleRoot)
require.NoError(wt.T, err)
}

View File

@ -16,6 +16,7 @@ import (
"github.com/databricks/cli/internal" "github.com/databricks/cli/internal"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/flags"
"github.com/databricks/cli/libs/template" "github.com/databricks/cli/libs/template"
"github.com/databricks/cli/libs/vfs" "github.com/databricks/cli/libs/vfs"
@ -42,7 +43,9 @@ func initTestTemplateWithBundleRoot(t *testing.T, ctx context.Context, templateN
cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles")
ctx = cmdio.InContext(ctx, cmd) ctx = cmdio.InContext(ctx, cmd)
err = template.Materialize(ctx, configFilePath, templateRoot, bundleRoot) out, err := filer.NewLocalClient(bundleRoot)
require.NoError(t, err)
err = template.Materialize(ctx, configFilePath, os.DirFS(templateRoot), out)
return bundleRoot, err return bundleRoot, err
} }

View File

@ -723,6 +723,63 @@ func TestAccWorkspaceFilesExtensionsDirectoriesAreNotNotebooks(t *testing.T) {
assert.ErrorIs(t, err, fs.ErrNotExist) assert.ErrorIs(t, err, fs.ErrNotExist)
} }
func TestAccWorkspaceFilesExtensionsNotebooksAreNotReadAsFiles(t *testing.T) {
t.Parallel()
ctx := context.Background()
wf, _ := setupWsfsExtensionsFiler(t)
// Create a notebook
err := wf.Write(ctx, "foo.ipynb", strings.NewReader(readFile(t, "testdata/notebooks/py1.ipynb")))
require.NoError(t, err)
// Reading foo should fail. Even though the WSFS name for the notebook is foo
// reading the notebook should only work with the .ipynb extension.
_, err = wf.Read(ctx, "foo")
assert.ErrorIs(t, err, fs.ErrNotExist)
_, err = wf.Read(ctx, "foo.ipynb")
assert.NoError(t, err)
}
func TestAccWorkspaceFilesExtensionsNotebooksAreNotStatAsFiles(t *testing.T) {
t.Parallel()
ctx := context.Background()
wf, _ := setupWsfsExtensionsFiler(t)
// Create a notebook
err := wf.Write(ctx, "foo.ipynb", strings.NewReader(readFile(t, "testdata/notebooks/py1.ipynb")))
require.NoError(t, err)
// Stating foo should fail. Even though the WSFS name for the notebook is foo
// stating the notebook should only work with the .ipynb extension.
_, err = wf.Stat(ctx, "foo")
assert.ErrorIs(t, err, fs.ErrNotExist)
_, err = wf.Stat(ctx, "foo.ipynb")
assert.NoError(t, err)
}
func TestAccWorkspaceFilesExtensionsNotebooksAreNotDeletedAsFiles(t *testing.T) {
t.Parallel()
ctx := context.Background()
wf, _ := setupWsfsExtensionsFiler(t)
// Create a notebook
err := wf.Write(ctx, "foo.ipynb", strings.NewReader(readFile(t, "testdata/notebooks/py1.ipynb")))
require.NoError(t, err)
// Deleting foo should fail. Even though the WSFS name for the notebook is foo
// deleting the notebook should only work with the .ipynb extension.
err = wf.Delete(ctx, "foo")
assert.ErrorIs(t, err, fs.ErrNotExist)
err = wf.Delete(ctx, "foo.ipynb")
assert.NoError(t, err)
}
func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) { func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -97,7 +97,7 @@ func TestAccBundleInitOnMlopsStacks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
job, err := w.Jobs.GetByJobId(context.Background(), batchJobId) job, err := w.Jobs.GetByJobId(context.Background(), batchJobId)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, fmt.Sprintf("dev-%s-batch-inference-job", projectName), job.Settings.Name) assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName))
} }
func TestAccBundleInitHelpers(t *testing.T) { func TestAccBundleInitHelpers(t *testing.T) {

View File

@ -9,6 +9,7 @@ import (
type TokenCache interface { type TokenCache interface {
Store(key string, t *oauth2.Token) error Store(key string, t *oauth2.Token) error
Lookup(key string) (*oauth2.Token, error) Lookup(key string) (*oauth2.Token, error)
Delete(key string) error
} }
var tokenCache int var tokenCache int

View File

@ -52,11 +52,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error {
c.Tokens = map[string]*oauth2.Token{} c.Tokens = map[string]*oauth2.Token{}
} }
c.Tokens[key] = t c.Tokens[key] = t
raw, err := json.MarshalIndent(c, "", " ") return c.write()
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
} }
func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
@ -73,6 +69,24 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) {
return t, nil return t, nil
} }
func (c *FileTokenCache) Delete(key string) error {
err := c.load()
if errors.Is(err, fs.ErrNotExist) {
return ErrNotConfigured
} else if err != nil {
return fmt.Errorf("load: %w", err)
}
if c.Tokens == nil {
c.Tokens = map[string]*oauth2.Token{}
}
_, ok := c.Tokens[key]
if !ok {
return ErrNotConfigured
}
delete(c.Tokens, key)
return c.write()
}
func (c *FileTokenCache) location() (string, error) { func (c *FileTokenCache) location() (string, error) {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
@ -105,4 +119,12 @@ func (c *FileTokenCache) load() error {
return nil return nil
} }
func (c *FileTokenCache) write() error {
raw, err := json.MarshalIndent(c, "", " ")
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return os.WriteFile(c.fileLocation, raw, ownerReadWrite)
}
var _ TokenCache = (*FileTokenCache)(nil) var _ TokenCache = (*FileTokenCache)(nil)

View File

@ -1,6 +1,7 @@
package cache package cache
import ( import (
"encoding/json"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -103,3 +104,64 @@ func TestStoreOnDev(t *testing.T) {
// macOS: read-only file system // macOS: read-only file system
assert.Error(t, err) assert.Error(t, err)
} }
func TestStoreAndDeleteKey(t *testing.T) {
setup(t)
c := &FileTokenCache{}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})
require.NoError(t, err)
err = c.Store("y", &oauth2.Token{
AccessToken: "bcd",
})
require.NoError(t, err)
l := &FileTokenCache{}
err = l.Delete("x")
require.NoError(t, err)
assert.Equal(t, 1, len(l.Tokens))
_, err = l.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
tok, err := l.Lookup("y")
require.NoError(t, err)
assert.Equal(t, "bcd", tok.AccessToken)
}
func TestDeleteKeyNotExist(t *testing.T) {
c := &FileTokenCache{
Tokens: map[string]*oauth2.Token{},
}
err := c.Delete("x")
assert.Equal(t, ErrNotConfigured, err)
_, err = c.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
}
func TestWrite(t *testing.T) {
tempFile := filepath.Join(t.TempDir(), "token-cache.json")
tokenMap := map[string]*oauth2.Token{}
token := &oauth2.Token{
AccessToken: "some-access-token",
}
tokenMap["test"] = token
cache := &FileTokenCache{
fileLocation: tempFile,
Tokens: tokenMap,
}
err := cache.write()
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
require.NoError(t, err)
expected, _ := json.MarshalIndent(&cache, "", " ")
assert.Equal(t, content, expected)
}

View File

@ -23,4 +23,14 @@ func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error {
return nil return nil
} }
// Delete implements TokenCache.
func (i *InMemoryTokenCache) Delete(key string) error {
_, ok := i.Tokens[key]
if !ok {
return ErrNotConfigured
}
delete(i.Tokens, key)
return nil
}
var _ TokenCache = (*InMemoryTokenCache)(nil) var _ TokenCache = (*InMemoryTokenCache)(nil)

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -42,3 +43,40 @@ func TestInMemoryCacheStore(t *testing.T) {
assert.Equal(t, res, token) assert.Equal(t, res, token)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestInMemoryDeleteKey(t *testing.T) {
c := &InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{},
}
err := c.Store("x", &oauth2.Token{
AccessToken: "abc",
})
require.NoError(t, err)
err = c.Store("y", &oauth2.Token{
AccessToken: "bcd",
})
require.NoError(t, err)
err = c.Delete("x")
require.NoError(t, err)
assert.Equal(t, 1, len(c.Tokens))
_, err = c.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
tok, err := c.Lookup("y")
require.NoError(t, err)
assert.Equal(t, "bcd", tok.AccessToken)
}
func TestInMemoryDeleteKeyNotExist(t *testing.T) {
c := &InMemoryTokenCache{
Tokens: map[string]*oauth2.Token{},
}
err := c.Delete("x")
assert.Equal(t, ErrNotConfigured, err)
_, err = c.Lookup("x")
assert.Equal(t, ErrNotConfigured, err)
}

View File

@ -144,6 +144,18 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
return nil return nil
} }
func (a *PersistentAuth) ClearToken(ctx context.Context) error {
if a.Host == "" && a.AccountID == "" {
return ErrFetchCredentials
}
if a.cache == nil {
a.cache = cache.GetTokenCache(ctx)
}
// lookup token identified by host (and possibly the account id)
key := a.key()
return a.cache.Delete(key)
}
// This function cleans up the host URL by only retaining the scheme and the host. // This function cleans up the host URL by only retaining the scheme and the host.
// This function thus removes any path, query arguments, or fragments from the URL. // This function thus removes any path, query arguments, or fragments from the URL.
func (a *PersistentAuth) cleanHost() { func (a *PersistentAuth) cleanHost() {

View File

@ -55,6 +55,7 @@ func TestOidcForWorkspace(t *testing.T) {
type tokenCacheMock struct { type tokenCacheMock struct {
store func(key string, t *oauth2.Token) error store func(key string, t *oauth2.Token) error
lookup func(key string) (*oauth2.Token, error) lookup func(key string) (*oauth2.Token, error)
delete func(key string) error
} }
func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error {
@ -71,6 +72,13 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) {
return m.lookup(key) return m.lookup(key)
} }
func (m *tokenCacheMock) Delete(key string) error {
if m.delete == nil {
panic("no deleteKey mock")
}
return m.delete(key)
}
func TestLoad(t *testing.T) { func TestLoad(t *testing.T) {
p := &PersistentAuth{ p := &PersistentAuth{
Host: "abc", Host: "abc",
@ -229,6 +237,52 @@ func TestChallengeFailed(t *testing.T) {
}) })
} }
func TestClearToken(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return &oauth2.Token{}, ErrNotConfigured
},
delete: func(key string) error {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return nil
},
},
}
defer p.Close()
err := p.ClearToken(context.Background())
assert.NoError(t, err)
key := p.key()
_, err = p.cache.Lookup(key)
assert.Equal(t, ErrNotConfigured, err)
}
func TestClearTokenNotExist(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
AccountID: "xyz",
cache: &tokenCacheMock{
lookup: func(key string) (*oauth2.Token, error) {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return &oauth2.Token{}, ErrNotConfigured
},
delete: func(key string) error {
assert.Equal(t, "https://abc/oidc/accounts/xyz", key)
return ErrNotConfigured
},
},
}
defer p.Close()
err := p.ClearToken(context.Background())
assert.Equal(t, ErrNotConfigured, err)
key := p.key()
_, err = p.cache.Lookup(key)
assert.Equal(t, ErrNotConfigured, err)
}
func TestPersistentAuthCleanHost(t *testing.T) { func TestPersistentAuthCleanHost(t *testing.T) {
for _, tcases := range []struct { for _, tcases := range []struct {
in string in string

View File

@ -7,13 +7,24 @@ import (
"io/fs" "io/fs"
) )
// WriteMode captures intent when writing a file.
//
// The first 9 bits are reserved for the [fs.FileMode] permission bits.
// These are used only by the local filer implementation and have
// no effect for the other implementations.
type WriteMode int type WriteMode int
// writeModePerm is a mask to extract permission bits from a WriteMode.
const writeModePerm = WriteMode(fs.ModePerm)
const ( const (
OverwriteIfExists WriteMode = 1 << iota // Note: these constants are defined as powers of 2 to support combining them using a bit-wise OR.
// They starts from the 10th bit (permission mask + 1) to avoid conflicts with the permission bits.
OverwriteIfExists WriteMode = (writeModePerm + 1) << iota
CreateParentDirectories CreateParentDirectories
) )
// DeleteMode captures intent when deleting a file.
type DeleteMode int type DeleteMode int
const ( const (

12
libs/filer/filer_test.go Normal file
View File

@ -0,0 +1,12 @@
package filer
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteMode(t *testing.T) {
assert.Equal(t, 512, int(OverwriteIfExists))
assert.Equal(t, 1024, int(CreateParentDirectories))
}

View File

@ -28,6 +28,15 @@ func (w *LocalClient) Write(ctx context.Context, name string, reader io.Reader,
return err return err
} }
// Retrieve permission mask from the [WriteMode], if present.
perm := fs.FileMode(0644)
for _, m := range mode {
bits := m & writeModePerm
if bits != 0 {
perm = fs.FileMode(bits)
}
}
flags := os.O_WRONLY | os.O_CREATE flags := os.O_WRONLY | os.O_CREATE
if slices.Contains(mode, OverwriteIfExists) { if slices.Contains(mode, OverwriteIfExists) {
flags |= os.O_TRUNC flags |= os.O_TRUNC
@ -35,7 +44,7 @@ func (w *LocalClient) Write(ctx context.Context, name string, reader io.Reader,
flags |= os.O_EXCL flags |= os.O_EXCL
} }
f, err := os.OpenFile(absPath, flags, 0644) f, err := os.OpenFile(absPath, flags, perm)
if errors.Is(err, fs.ErrNotExist) && slices.Contains(mode, CreateParentDirectories) { if errors.Is(err, fs.ErrNotExist) && slices.Contains(mode, CreateParentDirectories) {
// Create parent directories if they don't exist. // Create parent directories if they don't exist.
err = os.MkdirAll(filepath.Dir(absPath), 0755) err = os.MkdirAll(filepath.Dir(absPath), 0755)
@ -43,7 +52,7 @@ func (w *LocalClient) Write(ctx context.Context, name string, reader io.Reader,
return err return err
} }
// Try again. // Try again.
f, err = os.OpenFile(absPath, flags, 0644) f, err = os.OpenFile(absPath, flags, perm)
} }
if err != nil { if err != nil {

View File

@ -244,6 +244,17 @@ func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string,
// Try to read the file as a regular file. If the file is not found, try to read it as a notebook. // Try to read the file as a regular file. If the file is not found, try to read it as a notebook.
func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) {
// Ensure that the file / notebook exists. We do this check here to avoid reading
// the content of a notebook called `foo` when the user actually wanted
// to read the content of a file called `foo`.
//
// To read the content of a notebook called `foo` in the workspace the user
// should use the name with the extension included like `foo.ipynb` or `foo.sql`.
_, err := w.Stat(ctx, name)
if err != nil {
return nil, err
}
r, err := w.wsfs.Read(ctx, name) r, err := w.wsfs.Read(ctx, name)
// If the file is not found, it might be a notebook. // If the file is not found, it might be a notebook.
@ -276,7 +287,18 @@ func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string
return ReadOnlyError{"delete"} return ReadOnlyError{"delete"}
} }
err := w.wsfs.Delete(ctx, name, mode...) // Ensure that the file / notebook exists. We do this check here to avoid
// deleting the a notebook called `foo` when the user actually wanted to
// delete a file called `foo`.
//
// To delete a notebook called `foo` in the workspace the user should use the
// name with the extension included like `foo.ipynb` or `foo.sql`.
_, err := w.Stat(ctx, name)
if err != nil {
return err
}
err = w.wsfs.Delete(ctx, name, mode...)
// If the file is not found, it might be a notebook. // If the file is not found, it might be a notebook.
if errors.As(err, &FileDoesNotExistError{}) { if errors.As(err, &FileDoesNotExistError{}) {
@ -315,7 +337,24 @@ func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string)
return wsfsFileInfo{ObjectInfo: stat.ObjectInfo}, nil return wsfsFileInfo{ObjectInfo: stat.ObjectInfo}, nil
} }
return info, err if err != nil {
return nil, err
}
// If an object is found and it is a notebook, return a FileDoesNotExistError.
// If a notebook is found by the workspace files client, without having stripped
// the extension, this implies that no file with the same name exists.
//
// This check is done to avoid returning the stat for a notebook called `foo`
// when the user actually wanted to stat a file called `foo`.
//
// To stat the metadata of a notebook called `foo` in the workspace the user
// should use the name with the extension included like `foo.ipynb` or `foo.sql`.
if info.Sys().(workspace.ObjectInfo).ObjectType == workspace.ObjectTypeNotebook {
return nil, FileDoesNotExistError{name}
}
return info, nil
} }
// Note: The import API returns opaque internal errors for namespace clashes // Note: The import API returns opaque internal errors for namespace clashes

View File

@ -3,7 +3,9 @@ package jsonschema
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath"
"regexp" "regexp"
"slices" "slices"
@ -255,7 +257,12 @@ func (schema *Schema) validate() error {
} }
func Load(path string) (*Schema, error) { func Load(path string) (*Schema, error) {
b, err := os.ReadFile(path) dir, file := filepath.Split(path)
return LoadFS(os.DirFS(dir), file)
}
func LoadFS(fsys fs.FS, path string) (*Schema, error) {
b, err := fs.ReadFile(fsys, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package jsonschema package jsonschema
import ( import (
"os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -305,3 +306,9 @@ func TestValidateSchemaSkippedPropertiesHaveDefaults(t *testing.T) {
err = s.validate() err = s.validate()
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestSchema_LoadFS(t *testing.T) {
fsys := os.DirFS("./testdata/schema-load-int")
_, err := LoadFS(fsys, "schema-valid.json")
assert.NoError(t, err)
}

47
libs/template/builtin.go Normal file
View File

@ -0,0 +1,47 @@
package template
import (
"embed"
"io/fs"
)
//go:embed all:templates
var builtinTemplates embed.FS
// BuiltinTemplate represents a template that is built into the CLI.
type BuiltinTemplate struct {
Name string
FS fs.FS
}
// Builtin returns the list of all built-in templates.
func Builtin() ([]BuiltinTemplate, error) {
templates, err := fs.Sub(builtinTemplates, "templates")
if err != nil {
return nil, err
}
entries, err := fs.ReadDir(templates, ".")
if err != nil {
return nil, err
}
var out []BuiltinTemplate
for _, entry := range entries {
if !entry.IsDir() {
continue
}
templateFS, err := fs.Sub(templates, entry.Name())
if err != nil {
return nil, err
}
out = append(out, BuiltinTemplate{
Name: entry.Name(),
FS: templateFS,
})
}
return out, nil
}

View File

@ -0,0 +1,28 @@
package template
import (
"io/fs"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuiltin(t *testing.T) {
out, err := Builtin()
require.NoError(t, err)
assert.Len(t, out, 3)
// Confirm names.
assert.Equal(t, "dbt-sql", out[0].Name)
assert.Equal(t, "default-python", out[1].Name)
assert.Equal(t, "default-sql", out[2].Name)
// Confirm that the filesystems work.
_, err = fs.Stat(out[0].FS, `template/{{.project_name}}/dbt_project.yml.tmpl`)
assert.NoError(t, err)
_, err = fs.Stat(out[1].FS, `template/{{.project_name}}/tests/main_test.py.tmpl`)
assert.NoError(t, err)
_, err = fs.Stat(out[2].FS, `template/{{.project_name}}/src/orders_daily.sql.tmpl`)
assert.NoError(t, err)
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/jsonschema" "github.com/databricks/cli/libs/jsonschema"
@ -28,9 +29,8 @@ type config struct {
schema *jsonschema.Schema schema *jsonschema.Schema
} }
func newConfig(ctx context.Context, schemaPath string) (*config, error) { func newConfig(ctx context.Context, templateFS fs.FS, schemaPath string) (*config, error) {
// Read config schema schema, err := jsonschema.LoadFS(templateFS, schemaPath)
schema, err := jsonschema.Load(schemaPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,6 +3,8 @@ package template
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path"
"path/filepath" "path/filepath"
"testing" "testing"
"text/template" "text/template"
@ -16,7 +18,7 @@ func TestTemplateConfigAssignValuesFromFile(t *testing.T) {
testDir := "./testdata/config-assign-from-file" testDir := "./testdata/config-assign-from-file"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
err = c.assignValuesFromFile(filepath.Join(testDir, "config.json")) err = c.assignValuesFromFile(filepath.Join(testDir, "config.json"))
@ -32,7 +34,7 @@ func TestTemplateConfigAssignValuesFromFileDoesNotOverwriteExistingConfigs(t *te
testDir := "./testdata/config-assign-from-file" testDir := "./testdata/config-assign-from-file"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
c.values = map[string]any{ c.values = map[string]any{
@ -52,7 +54,7 @@ func TestTemplateConfigAssignValuesFromFileForInvalidIntegerValue(t *testing.T)
testDir := "./testdata/config-assign-from-file-invalid-int" testDir := "./testdata/config-assign-from-file-invalid-int"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
err = c.assignValuesFromFile(filepath.Join(testDir, "config.json")) err = c.assignValuesFromFile(filepath.Join(testDir, "config.json"))
@ -63,7 +65,7 @@ func TestTemplateConfigAssignValuesFromFileFiltersPropertiesNotInTheSchema(t *te
testDir := "./testdata/config-assign-from-file-unknown-property" testDir := "./testdata/config-assign-from-file-unknown-property"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
err = c.assignValuesFromFile(filepath.Join(testDir, "config.json")) err = c.assignValuesFromFile(filepath.Join(testDir, "config.json"))
@ -78,10 +80,10 @@ func TestTemplateConfigAssignValuesFromDefaultValues(t *testing.T) {
testDir := "./testdata/config-assign-from-default-value" testDir := "./testdata/config-assign-from-default-value"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
r, err := newRenderer(ctx, nil, nil, "./testdata/empty/template", "./testdata/empty/library", t.TempDir()) r, err := newRenderer(ctx, nil, nil, os.DirFS("."), "./testdata/empty/template", "./testdata/empty/library")
require.NoError(t, err) require.NoError(t, err)
err = c.assignDefaultValues(r) err = c.assignDefaultValues(r)
@ -97,10 +99,10 @@ func TestTemplateConfigAssignValuesFromTemplatedDefaultValues(t *testing.T) {
testDir := "./testdata/config-assign-from-templated-default-value" testDir := "./testdata/config-assign-from-templated-default-value"
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, filepath.Join(testDir, "schema.json")) c, err := newConfig(ctx, os.DirFS(testDir), "schema.json")
require.NoError(t, err) require.NoError(t, err)
r, err := newRenderer(ctx, nil, nil, filepath.Join(testDir, "template/template"), filepath.Join(testDir, "template/library"), t.TempDir()) r, err := newRenderer(ctx, nil, nil, os.DirFS("."), path.Join(testDir, "template/template"), path.Join(testDir, "template/library"))
require.NoError(t, err) require.NoError(t, err)
// Note: only the string value is templated. // Note: only the string value is templated.
@ -116,7 +118,7 @@ func TestTemplateConfigAssignValuesFromTemplatedDefaultValues(t *testing.T) {
func TestTemplateConfigValidateValuesDefined(t *testing.T) { func TestTemplateConfigValidateValuesDefined(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, "testdata/config-test-schema/test-schema.json") c, err := newConfig(ctx, os.DirFS("testdata/config-test-schema"), "test-schema.json")
require.NoError(t, err) require.NoError(t, err)
c.values = map[string]any{ c.values = map[string]any{
@ -131,7 +133,7 @@ func TestTemplateConfigValidateValuesDefined(t *testing.T) {
func TestTemplateConfigValidateTypeForValidConfig(t *testing.T) { func TestTemplateConfigValidateTypeForValidConfig(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, "testdata/config-test-schema/test-schema.json") c, err := newConfig(ctx, os.DirFS("testdata/config-test-schema"), "test-schema.json")
require.NoError(t, err) require.NoError(t, err)
c.values = map[string]any{ c.values = map[string]any{
@ -147,7 +149,7 @@ func TestTemplateConfigValidateTypeForValidConfig(t *testing.T) {
func TestTemplateConfigValidateTypeForUnknownField(t *testing.T) { func TestTemplateConfigValidateTypeForUnknownField(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, "testdata/config-test-schema/test-schema.json") c, err := newConfig(ctx, os.DirFS("testdata/config-test-schema"), "test-schema.json")
require.NoError(t, err) require.NoError(t, err)
c.values = map[string]any{ c.values = map[string]any{
@ -164,7 +166,7 @@ func TestTemplateConfigValidateTypeForUnknownField(t *testing.T) {
func TestTemplateConfigValidateTypeForInvalidType(t *testing.T) { func TestTemplateConfigValidateTypeForInvalidType(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c, err := newConfig(ctx, "testdata/config-test-schema/test-schema.json") c, err := newConfig(ctx, os.DirFS("testdata/config-test-schema"), "test-schema.json")
require.NoError(t, err) require.NoError(t, err)
c.values = map[string]any{ c.values = map[string]any{
@ -271,7 +273,8 @@ func TestTemplateEnumValidation(t *testing.T) {
} }
func TestTemplateSchemaErrorsWithEmptyDescription(t *testing.T) { func TestTemplateSchemaErrorsWithEmptyDescription(t *testing.T) {
_, err := newConfig(context.Background(), "./testdata/config-test-schema/invalid-test-schema.json") ctx := context.Background()
_, err := newConfig(ctx, os.DirFS("./testdata/config-test-schema"), "invalid-test-schema.json")
assert.EqualError(t, err, "template property property-without-description is missing a description") assert.EqualError(t, err, "template property property-without-description is missing a description")
} }

View File

@ -1,11 +1,10 @@
package template package template
import ( import (
"bytes"
"context" "context"
"io"
"io/fs" "io/fs"
"os" "slices"
"path/filepath"
"github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/filer"
) )
@ -13,89 +12,69 @@ import (
// Interface representing a file to be materialized from a template into a project // Interface representing a file to be materialized from a template into a project
// instance // instance
type file interface { type file interface {
// Destination path for file. This is where the file will be created when // Path of the file relative to the root of the instantiated template.
// PersistToDisk is called. // This is where the file is written to when persisting the template to disk.
DstPath() *destinationPath // Must be slash-separated.
RelPath() string
// Write file to disk at the destination path. // Write file to disk at the destination path.
PersistToDisk() error Write(ctx context.Context, out filer.Filer) error
}
type destinationPath struct { // contents returns the file contents as a byte slice.
// Root path for the project instance. This path uses the system's default // This is used for testing purposes.
// file separator. For example /foo/bar on Unix and C:\foo\bar on windows contents() ([]byte, error)
root string
// Unix like file path relative to the "root" of the instantiated project. Is used to
// evaluate whether the file should be skipped by comparing it to a list of
// skip glob patterns.
relPath string
}
// Absolute path of the file, in the os native format. For example /foo/bar on
// Unix and C:\foo\bar on windows
func (f *destinationPath) absPath() string {
return filepath.Join(f.root, filepath.FromSlash(f.relPath))
} }
type copyFile struct { type copyFile struct {
ctx context.Context
// Permissions bits for the destination file // Permissions bits for the destination file
perm fs.FileMode perm fs.FileMode
dstPath *destinationPath // Destination path for the file.
relPath string
// Filer rooted at template root. Used to read srcPath. // [fs.FS] rooted at template root. Used to read srcPath.
srcFiler filer.Filer srcFS fs.FS
// Relative path from template root for file to be copied. // Relative path from template root for file to be copied.
srcPath string srcPath string
} }
func (f *copyFile) DstPath() *destinationPath { func (f *copyFile) RelPath() string {
return f.dstPath return f.relPath
} }
func (f *copyFile) PersistToDisk() error { func (f *copyFile) Write(ctx context.Context, out filer.Filer) error {
path := f.DstPath().absPath() src, err := f.srcFS.Open(f.srcPath)
err := os.MkdirAll(filepath.Dir(path), 0755)
if err != nil { if err != nil {
return err return err
} }
srcFile, err := f.srcFiler.Read(f.ctx, f.srcPath) defer src.Close()
if err != nil { return out.Write(ctx, f.relPath, src, filer.CreateParentDirectories, filer.WriteMode(f.perm))
return err }
}
defer srcFile.Close() func (f *copyFile) contents() ([]byte, error) {
dstFile, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, f.perm) return fs.ReadFile(f.srcFS, f.srcPath)
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
} }
type inMemoryFile struct { type inMemoryFile struct {
dstPath *destinationPath
content []byte
// Permissions bits for the destination file // Permissions bits for the destination file
perm fs.FileMode perm fs.FileMode
// Destination path for the file.
relPath string
// Contents of the file.
content []byte
} }
func (f *inMemoryFile) DstPath() *destinationPath { func (f *inMemoryFile) RelPath() string {
return f.dstPath return f.relPath
} }
func (f *inMemoryFile) PersistToDisk() error { func (f *inMemoryFile) Write(ctx context.Context, out filer.Filer) error {
path := f.DstPath().absPath() return out.Write(ctx, f.relPath, bytes.NewReader(f.content), filer.CreateParentDirectories, filer.WriteMode(f.perm))
}
err := os.MkdirAll(filepath.Dir(path), 0755)
if err != nil { func (f *inMemoryFile) contents() ([]byte, error) {
return err return slices.Clone(f.content), nil
}
return os.WriteFile(path, f.content, f.perm)
} }

View File

@ -13,76 +13,51 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func testInMemoryFile(t *testing.T, perm fs.FileMode) { func testInMemoryFile(t *testing.T, ctx context.Context, perm fs.FileMode) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
f := &inMemoryFile{ f := &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "a/b/c",
},
perm: perm, perm: perm,
relPath: "a/b/c",
content: []byte("123"), content: []byte("123"),
} }
err := f.PersistToDisk()
out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = f.Write(ctx, out)
assert.NoError(t, err) assert.NoError(t, err)
assertFileContent(t, filepath.Join(tmpDir, "a/b/c"), "123") assertFileContent(t, filepath.Join(tmpDir, "a/b/c"), "123")
assertFilePermissions(t, filepath.Join(tmpDir, "a/b/c"), perm) assertFilePermissions(t, filepath.Join(tmpDir, "a/b/c"), perm)
} }
func testCopyFile(t *testing.T, perm fs.FileMode) { func testCopyFile(t *testing.T, ctx context.Context, perm fs.FileMode) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
err := os.WriteFile(filepath.Join(tmpDir, "source"), []byte("qwerty"), perm)
templateFiler, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(tmpDir, "source"), []byte("qwerty"), perm)
require.NoError(t, err) require.NoError(t, err)
f := &copyFile{ f := &copyFile{
ctx: context.Background(), perm: perm,
dstPath: &destinationPath{ relPath: "a/b/c",
root: tmpDir, srcFS: os.DirFS(tmpDir),
relPath: "a/b/c", srcPath: "source",
},
perm: perm,
srcPath: "source",
srcFiler: templateFiler,
} }
err = f.PersistToDisk()
out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = f.Write(ctx, out)
assert.NoError(t, err) assert.NoError(t, err)
assertFileContent(t, filepath.Join(tmpDir, "a/b/c"), "qwerty") assertFileContent(t, filepath.Join(tmpDir, "a/b/c"), "qwerty")
assertFilePermissions(t, filepath.Join(tmpDir, "a/b/c"), perm) assertFilePermissions(t, filepath.Join(tmpDir, "a/b/c"), perm)
} }
func TestTemplateFileDestinationPath(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
f := &destinationPath{
root: `a/b/c`,
relPath: "d/e",
}
assert.Equal(t, `a/b/c/d/e`, f.absPath())
}
func TestTemplateFileDestinationPathForWindows(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
f := &destinationPath{
root: `c:\a\b\c`,
relPath: "d/e",
}
assert.Equal(t, `c:\a\b\c\d\e`, f.absPath())
}
func TestTemplateInMemoryFilePersistToDisk(t *testing.T) { func TestTemplateInMemoryFilePersistToDisk(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.SkipNow() t.SkipNow()
} }
testInMemoryFile(t, 0755) ctx := context.Background()
testInMemoryFile(t, ctx, 0755)
} }
func TestTemplateInMemoryFilePersistToDiskForWindows(t *testing.T) { func TestTemplateInMemoryFilePersistToDiskForWindows(t *testing.T) {
@ -91,14 +66,16 @@ func TestTemplateInMemoryFilePersistToDiskForWindows(t *testing.T) {
} }
// we have separate tests for windows because of differences in valid // we have separate tests for windows because of differences in valid
// fs.FileMode values we can use for different operating systems. // fs.FileMode values we can use for different operating systems.
testInMemoryFile(t, 0666) ctx := context.Background()
testInMemoryFile(t, ctx, 0666)
} }
func TestTemplateCopyFilePersistToDisk(t *testing.T) { func TestTemplateCopyFilePersistToDisk(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.SkipNow() t.SkipNow()
} }
testCopyFile(t, 0644) ctx := context.Background()
testCopyFile(t, ctx, 0644)
} }
func TestTemplateCopyFilePersistToDiskForWindows(t *testing.T) { func TestTemplateCopyFilePersistToDiskForWindows(t *testing.T) {
@ -107,5 +84,6 @@ func TestTemplateCopyFilePersistToDiskForWindows(t *testing.T) {
} }
// we have separate tests for windows because of differences in valid // we have separate tests for windows because of differences in valid
// fs.FileMode values we can use for different operating systems. // fs.FileMode values we can use for different operating systems.
testCopyFile(t, 0666) ctx := context.Background()
testCopyFile(t, ctx, 0666)
} }

View File

@ -18,11 +18,10 @@ import (
func TestTemplatePrintStringWithoutProcessing(t *testing.T) { func TestTemplatePrintStringWithoutProcessing(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/print-without-processing/template", "./testdata/print-without-processing/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/print-without-processing/template", "./testdata/print-without-processing/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -35,11 +34,10 @@ func TestTemplatePrintStringWithoutProcessing(t *testing.T) {
func TestTemplateRegexpCompileFunction(t *testing.T) { func TestTemplateRegexpCompileFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/regexp-compile/template", "./testdata/regexp-compile/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/regexp-compile/template", "./testdata/regexp-compile/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -53,11 +51,10 @@ func TestTemplateRegexpCompileFunction(t *testing.T) {
func TestTemplateRandIntFunction(t *testing.T) { func TestTemplateRandIntFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/random-int/template", "./testdata/random-int/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/random-int/template", "./testdata/random-int/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -71,11 +68,10 @@ func TestTemplateRandIntFunction(t *testing.T) {
func TestTemplateUuidFunction(t *testing.T) { func TestTemplateUuidFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/uuid/template", "./testdata/uuid/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/uuid/template", "./testdata/uuid/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -88,11 +84,10 @@ func TestTemplateUuidFunction(t *testing.T) {
func TestTemplateUrlFunction(t *testing.T) { func TestTemplateUrlFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/urlparse-function/template", "./testdata/urlparse-function/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/urlparse-function/template", "./testdata/urlparse-function/library")
require.NoError(t, err) require.NoError(t, err)
@ -105,11 +100,10 @@ func TestTemplateUrlFunction(t *testing.T) {
func TestTemplateMapPairFunction(t *testing.T) { func TestTemplateMapPairFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/map-pair/template", "./testdata/map-pair/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/map-pair/template", "./testdata/map-pair/library")
require.NoError(t, err) require.NoError(t, err)
@ -122,7 +116,6 @@ func TestTemplateMapPairFunction(t *testing.T) {
func TestWorkspaceHost(t *testing.T) { func TestWorkspaceHost(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tmpDir := t.TempDir()
w := &databricks.WorkspaceClient{ w := &databricks.WorkspaceClient{
Config: &workspaceConfig.Config{ Config: &workspaceConfig.Config{
@ -132,7 +125,7 @@ func TestWorkspaceHost(t *testing.T) {
ctx = root.SetWorkspaceClient(ctx, w) ctx = root.SetWorkspaceClient(ctx, w)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/workspace-host/template", "./testdata/map-pair/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/workspace-host/template", "./testdata/map-pair/library")
require.NoError(t, err) require.NoError(t, err)
@ -149,7 +142,6 @@ func TestWorkspaceHostNotConfigured(t *testing.T) {
ctx := context.Background() ctx := context.Background()
cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "template") cmd := cmdio.NewIO(flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "template")
ctx = cmdio.InContext(ctx, cmd) ctx = cmdio.InContext(ctx, cmd)
tmpDir := t.TempDir()
w := &databricks.WorkspaceClient{ w := &databricks.WorkspaceClient{
Config: &workspaceConfig.Config{}, Config: &workspaceConfig.Config{},
@ -157,7 +149,7 @@ func TestWorkspaceHostNotConfigured(t *testing.T) {
ctx = root.SetWorkspaceClient(ctx, w) ctx = root.SetWorkspaceClient(ctx, w)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/workspace-host/template", "./testdata/map-pair/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/workspace-host/template", "./testdata/map-pair/library")
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -2,54 +2,32 @@ package template
import ( import (
"context" "context"
"embed"
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os"
"path"
"path/filepath"
"github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/filer"
) )
const libraryDirName = "library" const libraryDirName = "library"
const templateDirName = "template" const templateDirName = "template"
const schemaFileName = "databricks_template_schema.json" const schemaFileName = "databricks_template_schema.json"
//go:embed all:templates
var builtinTemplates embed.FS
// This function materializes the input templates as a project, using user defined // This function materializes the input templates as a project, using user defined
// configurations. // configurations.
// Parameters: // Parameters:
// //
// ctx: context containing a cmdio object. This is used to prompt the user // ctx: context containing a cmdio object. This is used to prompt the user
// configFilePath: file path containing user defined config values // configFilePath: file path containing user defined config values
// templateRoot: root of the template definition // templateFS: root of the template definition
// outputDir: root of directory where to initialize the template // outputFiler: filer to use for writing the initialized template
func Materialize(ctx context.Context, configFilePath, templateRoot, outputDir string) error { func Materialize(ctx context.Context, configFilePath string, templateFS fs.FS, outputFiler filer.Filer) error {
// Use a temporary directory in case any builtin templates like default-python are used if _, err := fs.Stat(templateFS, schemaFileName); errors.Is(err, fs.ErrNotExist) {
tempDir, err := os.MkdirTemp("", "templates") return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName)
defer os.RemoveAll(tempDir)
if err != nil {
return err
}
templateRoot, err = prepareBuiltinTemplates(templateRoot, tempDir)
if err != nil {
return err
} }
templatePath := filepath.Join(templateRoot, templateDirName) config, err := newConfig(ctx, templateFS, schemaFileName)
libraryPath := filepath.Join(templateRoot, libraryDirName)
schemaPath := filepath.Join(templateRoot, schemaFileName)
helpers := loadHelpers(ctx)
if _, err := os.Stat(schemaPath); errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaPath)
}
config, err := newConfig(ctx, schemaPath)
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +40,8 @@ func Materialize(ctx context.Context, configFilePath, templateRoot, outputDir st
} }
} }
r, err := newRenderer(ctx, config.values, helpers, templatePath, libraryPath, outputDir) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, config.values, helpers, templateFS, templateDirName, libraryDirName)
if err != nil { if err != nil {
return err return err
} }
@ -94,7 +73,7 @@ func Materialize(ctx context.Context, configFilePath, templateRoot, outputDir st
return err return err
} }
err = r.persistToDisk() err = r.persistToDisk(ctx, outputFiler)
if err != nil { if err != nil {
return err return err
} }
@ -111,44 +90,3 @@ func Materialize(ctx context.Context, configFilePath, templateRoot, outputDir st
} }
return nil return nil
} }
// If the given templateRoot matches
func prepareBuiltinTemplates(templateRoot string, tempDir string) (string, error) {
// Check that `templateRoot` is a clean basename, i.e. `some_path` and not `./some_path` or "."
// Return early if that's not the case.
if templateRoot == "." || path.Base(templateRoot) != templateRoot {
return templateRoot, nil
}
_, err := fs.Stat(builtinTemplates, path.Join("templates", templateRoot))
if err != nil {
// The given path doesn't appear to be using out built-in templates
return templateRoot, nil
}
// We have a built-in template with the same name as templateRoot!
// Now we need to make a fully copy of the builtin templates to a real file system
// since template.Parse() doesn't support embed.FS.
err = fs.WalkDir(builtinTemplates, "templates", func(path string, entry fs.DirEntry, err error) error {
if err != nil {
return err
}
targetPath := filepath.Join(tempDir, path)
if entry.IsDir() {
return os.Mkdir(targetPath, 0755)
} else {
content, err := fs.ReadFile(builtinTemplates, path)
if err != nil {
return err
}
return os.WriteFile(targetPath, content, 0644)
}
})
if err != nil {
return "", err
}
return filepath.Join(tempDir, "templates", templateRoot), nil
}

View File

@ -3,7 +3,7 @@ package template
import ( import (
"context" "context"
"fmt" "fmt"
"path/filepath" "os"
"testing" "testing"
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
@ -19,6 +19,6 @@ func TestMaterializeForNonTemplateDirectory(t *testing.T) {
ctx := root.SetWorkspaceClient(context.Background(), w) ctx := root.SetWorkspaceClient(context.Background(), w)
// Try to materialize a non-template directory. // Try to materialize a non-template directory.
err = Materialize(ctx, "", tmpDir, "") err = Materialize(ctx, "", os.DirFS(tmpDir), nil)
assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", filepath.Join(tmpDir, schemaFileName))) assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName))
} }

View File

@ -6,9 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"os"
"path" "path"
"path/filepath"
"regexp" "regexp"
"slices" "slices"
"sort" "sort"
@ -52,32 +50,38 @@ type renderer struct {
// do not match any glob patterns from this list // do not match any glob patterns from this list
skipPatterns []string skipPatterns []string
// Filer rooted at template root. The file tree from this root is walked to // [fs.FS] that holds the template's file tree.
// generate the project srcFS fs.FS
templateFiler filer.Filer
// Root directory for the project instantiated from the template
instanceRoot string
} }
func newRenderer(ctx context.Context, config map[string]any, helpers template.FuncMap, templateRoot, libraryRoot, instanceRoot string) (*renderer, error) { func newRenderer(
ctx context.Context,
config map[string]any,
helpers template.FuncMap,
templateFS fs.FS,
templateDir string,
libraryDir string,
) (*renderer, error) {
// Initialize new template, with helper functions loaded // Initialize new template, with helper functions loaded
tmpl := template.New("").Funcs(helpers) tmpl := template.New("").Funcs(helpers)
// Load user defined associated templates from the library root // Find user-defined templates in the library directory
libraryGlob := filepath.Join(libraryRoot, "*") matches, err := fs.Glob(templateFS, path.Join(libraryDir, "*"))
matches, err := filepath.Glob(libraryGlob)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Parse user-defined templates.
// Note: we do not call [ParseFS] with the glob directly because
// it returns an error if no files match the pattern.
if len(matches) != 0 { if len(matches) != 0 {
tmpl, err = tmpl.ParseFiles(matches...) tmpl, err = tmpl.ParseFS(templateFS, matches...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
templateFiler, err := filer.NewLocalClient(templateRoot) srcFS, err := fs.Sub(templateFS, path.Clean(templateDir))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -85,13 +89,12 @@ func newRenderer(ctx context.Context, config map[string]any, helpers template.Fu
ctx = log.NewContext(ctx, log.GetLogger(ctx).With("action", "initialize-template")) ctx = log.NewContext(ctx, log.GetLogger(ctx).With("action", "initialize-template"))
return &renderer{ return &renderer{
ctx: ctx, ctx: ctx,
config: config, config: config,
baseTemplate: tmpl, baseTemplate: tmpl,
files: make([]file, 0), files: make([]file, 0),
skipPatterns: make([]string, 0), skipPatterns: make([]string, 0),
templateFiler: templateFiler, srcFS: srcFS,
instanceRoot: instanceRoot,
}, nil }, nil
} }
@ -141,7 +144,7 @@ func (r *renderer) executeTemplate(templateDefinition string) (string, error) {
func (r *renderer) computeFile(relPathTemplate string) (file, error) { func (r *renderer) computeFile(relPathTemplate string) (file, error) {
// read file permissions // read file permissions
info, err := r.templateFiler.Stat(r.ctx, relPathTemplate) info, err := fs.Stat(r.srcFS, relPathTemplate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -157,14 +160,10 @@ func (r *renderer) computeFile(relPathTemplate string) (file, error) {
// over as is, without treating it as a template // over as is, without treating it as a template
if !strings.HasSuffix(relPathTemplate, templateExtension) { if !strings.HasSuffix(relPathTemplate, templateExtension) {
return &copyFile{ return &copyFile{
dstPath: &destinationPath{ perm: perm,
root: r.instanceRoot, relPath: relPath,
relPath: relPath, srcFS: r.srcFS,
}, srcPath: relPathTemplate,
perm: perm,
ctx: r.ctx,
srcPath: relPathTemplate,
srcFiler: r.templateFiler,
}, nil }, nil
} else { } else {
// Trim the .tmpl suffix from file name, if specified in the template // Trim the .tmpl suffix from file name, if specified in the template
@ -173,7 +172,7 @@ func (r *renderer) computeFile(relPathTemplate string) (file, error) {
} }
// read template file's content // read template file's content
templateReader, err := r.templateFiler.Read(r.ctx, relPathTemplate) templateReader, err := r.srcFS.Open(relPathTemplate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -194,11 +193,8 @@ func (r *renderer) computeFile(relPathTemplate string) (file, error) {
} }
return &inMemoryFile{ return &inMemoryFile{
dstPath: &destinationPath{
root: r.instanceRoot,
relPath: relPath,
},
perm: perm, perm: perm,
relPath: relPath,
content: []byte(content), content: []byte(content),
}, nil }, nil
} }
@ -263,7 +259,7 @@ func (r *renderer) walk() error {
// //
// 2. For directories: They are appended to a slice, which acts as a queue // 2. For directories: They are appended to a slice, which acts as a queue
// allowing BFS traversal of the template file tree // allowing BFS traversal of the template file tree
entries, err := r.templateFiler.ReadDir(r.ctx, currentDirectory) entries, err := fs.ReadDir(r.srcFS, currentDirectory)
if err != nil { if err != nil {
return err return err
} }
@ -283,7 +279,7 @@ func (r *renderer) walk() error {
if err != nil { if err != nil {
return err return err
} }
logger.Infof(r.ctx, "added file to list of possible project files: %s", f.DstPath().relPath) logger.Infof(r.ctx, "added file to list of possible project files: %s", f.RelPath())
r.files = append(r.files, f) r.files = append(r.files, f)
} }
@ -291,17 +287,17 @@ func (r *renderer) walk() error {
return nil return nil
} }
func (r *renderer) persistToDisk() error { func (r *renderer) persistToDisk(ctx context.Context, out filer.Filer) error {
// Accumulate files which we will persist, skipping files whose path matches // Accumulate files which we will persist, skipping files whose path matches
// any of the skip patterns // any of the skip patterns
filesToPersist := make([]file, 0) filesToPersist := make([]file, 0)
for _, file := range r.files { for _, file := range r.files {
match, err := isSkipped(file.DstPath().relPath, r.skipPatterns) match, err := isSkipped(file.RelPath(), r.skipPatterns)
if err != nil { if err != nil {
return err return err
} }
if match { if match {
log.Infof(r.ctx, "skipping file: %s", file.DstPath()) log.Infof(r.ctx, "skipping file: %s", file.RelPath())
continue continue
} }
filesToPersist = append(filesToPersist, file) filesToPersist = append(filesToPersist, file)
@ -309,8 +305,8 @@ func (r *renderer) persistToDisk() error {
// Assert no conflicting files exist // Assert no conflicting files exist
for _, file := range filesToPersist { for _, file := range filesToPersist {
path := file.DstPath().absPath() path := file.RelPath()
_, err := os.Stat(path) _, err := out.Stat(ctx, path)
if err == nil { if err == nil {
return fmt.Errorf("failed to initialize template, one or more files already exist: %s", path) return fmt.Errorf("failed to initialize template, one or more files already exist: %s", path)
} }
@ -321,7 +317,7 @@ func (r *renderer) persistToDisk() error {
// Persist files to disk // Persist files to disk
for _, file := range filesToPersist { for _, file := range filesToPersist {
err := file.PersistToDisk() err := file.Write(ctx, out)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,9 +3,9 @@ package template
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
@ -18,6 +18,7 @@ import (
"github.com/databricks/cli/cmd/root" "github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/tags" "github.com/databricks/cli/libs/tags"
"github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go"
workspaceConfig "github.com/databricks/databricks-sdk-go/config" workspaceConfig "github.com/databricks/databricks-sdk-go/config"
@ -41,9 +42,8 @@ func assertFilePermissions(t *testing.T, path string, perm fs.FileMode) {
func assertBuiltinTemplateValid(t *testing.T, template string, settings map[string]any, target string, isServicePrincipal bool, build bool, tempDir string) { func assertBuiltinTemplateValid(t *testing.T, template string, settings map[string]any, target string, isServicePrincipal bool, build bool, tempDir string) {
ctx := context.Background() ctx := context.Background()
templatePath, err := prepareBuiltinTemplates(template, tempDir) templateFS, err := fs.Sub(builtinTemplates, path.Join("templates", template))
require.NoError(t, err) require.NoError(t, err)
libraryPath := filepath.Join(templatePath, "library")
w := &databricks.WorkspaceClient{ w := &databricks.WorkspaceClient{
Config: &workspaceConfig.Config{Host: "https://myhost.com"}, Config: &workspaceConfig.Config{Host: "https://myhost.com"},
@ -58,16 +58,18 @@ func assertBuiltinTemplateValid(t *testing.T, template string, settings map[stri
ctx = root.SetWorkspaceClient(ctx, w) ctx = root.SetWorkspaceClient(ctx, w)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
renderer, err := newRenderer(ctx, settings, helpers, templatePath, libraryPath, tempDir) renderer, err := newRenderer(ctx, settings, helpers, templateFS, templateDirName, libraryDirName)
require.NoError(t, err) require.NoError(t, err)
// Evaluate template // Evaluate template
err = renderer.walk() err = renderer.walk()
require.NoError(t, err) require.NoError(t, err)
err = renderer.persistToDisk() out, err := filer.NewLocalClient(tempDir)
require.NoError(t, err)
err = renderer.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
b, err := bundle.Load(ctx, filepath.Join(tempDir, "template", "my_project")) b, err := bundle.Load(ctx, filepath.Join(tempDir, "my_project"))
require.NoError(t, err) require.NoError(t, err)
diags := bundle.Apply(ctx, b, phases.LoadNamedTarget(target)) diags := bundle.Apply(ctx, b, phases.LoadNamedTarget(target))
require.NoError(t, diags.Error()) require.NoError(t, diags.Error())
@ -96,18 +98,6 @@ func assertBuiltinTemplateValid(t *testing.T, template string, settings map[stri
} }
} }
func TestPrepareBuiltInTemplatesWithRelativePaths(t *testing.T) {
// CWD should not be resolved as a built in template
dir, err := prepareBuiltinTemplates(".", t.TempDir())
assert.NoError(t, err)
assert.Equal(t, ".", dir)
// relative path should not be resolved as a built in template
dir, err = prepareBuiltinTemplates("./default-python", t.TempDir())
assert.NoError(t, err)
assert.Equal(t, "./default-python", dir)
}
func TestBuiltinPythonTemplateValid(t *testing.T) { func TestBuiltinPythonTemplateValid(t *testing.T) {
// Test option combinations // Test option combinations
options := []string{"yes", "no"} options := []string{"yes", "no"}
@ -194,13 +184,14 @@ func TestRendererWithAssociatedTemplateInLibrary(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/email/template", "./testdata/email/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/email/template", "./testdata/email/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
require.NoError(t, err) require.NoError(t, err)
out, err := filer.NewLocalClient(tmpDir)
err = r.persistToDisk() require.NoError(t, err)
err = r.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
b, err := os.ReadFile(filepath.Join(tmpDir, "my_email")) b, err := os.ReadFile(filepath.Join(tmpDir, "my_email"))
@ -325,45 +316,34 @@ func TestRendererPersistToDisk(t *testing.T) {
r := &renderer{ r := &renderer{
ctx: ctx, ctx: ctx,
instanceRoot: tmpDir,
skipPatterns: []string{"a/b/c", "mn*"}, skipPatterns: []string{"a/b/c", "mn*"},
files: []file{ files: []file{
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "a/b/c",
},
perm: 0444, perm: 0444,
relPath: "a/b/c",
content: nil, content: nil,
}, },
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "mno",
},
perm: 0444, perm: 0444,
relPath: "mno",
content: nil, content: nil,
}, },
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "a/b/d",
},
perm: 0444, perm: 0444,
relPath: "a/b/d",
content: []byte("123"), content: []byte("123"),
}, },
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "mmnn",
},
perm: 0444, perm: 0444,
relPath: "mmnn",
content: []byte("456"), content: []byte("456"),
}, },
}, },
} }
err := r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = r.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
assert.NoFileExists(t, filepath.Join(tmpDir, "a", "b", "c")) assert.NoFileExists(t, filepath.Join(tmpDir, "a", "b", "c"))
@ -378,10 +358,9 @@ func TestRendererPersistToDisk(t *testing.T) {
func TestRendererWalk(t *testing.T) { func TestRendererWalk(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/walk/template", "./testdata/walk/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/walk/template", "./testdata/walk/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -389,21 +368,12 @@ func TestRendererWalk(t *testing.T) {
getContent := func(r *renderer, path string) string { getContent := func(r *renderer, path string) string {
for _, f := range r.files { for _, f := range r.files {
if f.DstPath().relPath != path { if f.RelPath() != path {
continue continue
} }
switch v := f.(type) { b, err := f.contents()
case *inMemoryFile: require.NoError(t, err)
return strings.Trim(string(v.content), "\r\n") return strings.Trim(string(b), "\r\n")
case *copyFile:
r, err := r.templateFiler.Read(context.Background(), v.srcPath)
require.NoError(t, err)
b, err := io.ReadAll(r)
require.NoError(t, err)
return strings.Trim(string(b), "\r\n")
default:
require.FailNow(t, "execution should not reach here")
}
} }
require.FailNow(t, "file is absent: "+path) require.FailNow(t, "file is absent: "+path)
return "" return ""
@ -419,10 +389,9 @@ func TestRendererWalk(t *testing.T) {
func TestRendererFailFunction(t *testing.T) { func TestRendererFailFunction(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/fail/template", "./testdata/fail/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/fail/template", "./testdata/fail/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -432,10 +401,9 @@ func TestRendererFailFunction(t *testing.T) {
func TestRendererSkipsDirsEagerly(t *testing.T) { func TestRendererSkipsDirsEagerly(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/skip-dir-eagerly/template", "./testdata/skip-dir-eagerly/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/skip-dir-eagerly/template", "./testdata/skip-dir-eagerly/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -452,7 +420,7 @@ func TestRendererSkipAllFilesInCurrentDirectory(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/skip-all-files-in-cwd/template", "./testdata/skip-all-files-in-cwd/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/skip-all-files-in-cwd/template", "./testdata/skip-all-files-in-cwd/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -460,7 +428,9 @@ func TestRendererSkipAllFilesInCurrentDirectory(t *testing.T) {
// All 3 files are executed and have in memory representations // All 3 files are executed and have in memory representations
require.Len(t, r.files, 3) require.Len(t, r.files, 3)
err = r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = r.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
entries, err := os.ReadDir(tmpDir) entries, err := os.ReadDir(tmpDir)
@ -472,10 +442,9 @@ func TestRendererSkipAllFilesInCurrentDirectory(t *testing.T) {
func TestRendererSkipPatternsAreRelativeToFileDirectory(t *testing.T) { func TestRendererSkipPatternsAreRelativeToFileDirectory(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/skip-is-relative/template", "./testdata/skip-is-relative/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/skip-is-relative/template", "./testdata/skip-is-relative/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -493,7 +462,7 @@ func TestRendererSkip(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/skip/template", "./testdata/skip/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/skip/template", "./testdata/skip/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -502,7 +471,9 @@ func TestRendererSkip(t *testing.T) {
// This is because "dir2/*" matches the files in dir2, but not dir2 itself // This is because "dir2/*" matches the files in dir2, but not dir2 itself
assert.Len(t, r.files, 6) assert.Len(t, r.files, 6)
err = r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = r.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(tmpDir, "file1")) assert.FileExists(t, filepath.Join(tmpDir, "file1"))
@ -520,12 +491,11 @@ func TestRendererReadsPermissionsBits(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.SkipNow() t.SkipNow()
} }
tmpDir := t.TempDir()
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/executable-bit-read/template", "./testdata/executable-bit-read/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/executable-bit-read/template", "./testdata/executable-bit-read/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -533,7 +503,7 @@ func TestRendererReadsPermissionsBits(t *testing.T) {
getPermissions := func(r *renderer, path string) fs.FileMode { getPermissions := func(r *renderer, path string) fs.FileMode {
for _, f := range r.files { for _, f := range r.files {
if f.DstPath().relPath != path { if f.RelPath() != path {
continue continue
} }
switch v := f.(type) { switch v := f.(type) {
@ -556,6 +526,7 @@ func TestRendererReadsPermissionsBits(t *testing.T) {
func TestRendererErrorOnConflictingFile(t *testing.T) { func TestRendererErrorOnConflictingFile(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
ctx := context.Background()
f, err := os.Create(filepath.Join(tmpDir, "a")) f, err := os.Create(filepath.Join(tmpDir, "a"))
require.NoError(t, err) require.NoError(t, err)
@ -566,17 +537,16 @@ func TestRendererErrorOnConflictingFile(t *testing.T) {
skipPatterns: []string{}, skipPatterns: []string{},
files: []file{ files: []file{
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "a",
},
perm: 0444, perm: 0444,
relPath: "a",
content: []byte("123"), content: []byte("123"),
}, },
}, },
} }
err = r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
assert.EqualError(t, err, fmt.Sprintf("failed to initialize template, one or more files already exist: %s", filepath.Join(tmpDir, "a"))) require.NoError(t, err)
err = r.persistToDisk(ctx, out)
assert.EqualError(t, err, fmt.Sprintf("failed to initialize template, one or more files already exist: %s", "a"))
} }
func TestRendererNoErrorOnConflictingFileIfSkipped(t *testing.T) { func TestRendererNoErrorOnConflictingFileIfSkipped(t *testing.T) {
@ -593,16 +563,15 @@ func TestRendererNoErrorOnConflictingFileIfSkipped(t *testing.T) {
skipPatterns: []string{"a"}, skipPatterns: []string{"a"},
files: []file{ files: []file{
&inMemoryFile{ &inMemoryFile{
dstPath: &destinationPath{
root: tmpDir,
relPath: "a",
},
perm: 0444, perm: 0444,
relPath: "a",
content: []byte("123"), content: []byte("123"),
}, },
}, },
} }
err = r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = r.persistToDisk(ctx, out)
// No error is returned even though a conflicting file exists. This is because // No error is returned even though a conflicting file exists. This is because
// the generated file is being skipped // the generated file is being skipped
assert.NoError(t, err) assert.NoError(t, err)
@ -612,10 +581,9 @@ func TestRendererNoErrorOnConflictingFileIfSkipped(t *testing.T) {
func TestRendererNonTemplatesAreCreatedAsCopyFiles(t *testing.T) { func TestRendererNonTemplatesAreCreatedAsCopyFiles(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = root.SetWorkspaceClient(ctx, nil) ctx = root.SetWorkspaceClient(ctx, nil)
tmpDir := t.TempDir()
helpers := loadHelpers(ctx) helpers := loadHelpers(ctx)
r, err := newRenderer(ctx, nil, helpers, "./testdata/copy-file-walk/template", "./testdata/copy-file-walk/library", tmpDir) r, err := newRenderer(ctx, nil, helpers, os.DirFS("."), "./testdata/copy-file-walk/template", "./testdata/copy-file-walk/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -623,7 +591,7 @@ func TestRendererNonTemplatesAreCreatedAsCopyFiles(t *testing.T) {
assert.Len(t, r.files, 1) assert.Len(t, r.files, 1)
assert.Equal(t, r.files[0].(*copyFile).srcPath, "not-a-template") assert.Equal(t, r.files[0].(*copyFile).srcPath, "not-a-template")
assert.Equal(t, r.files[0].DstPath().absPath(), filepath.Join(tmpDir, "not-a-template")) assert.Equal(t, r.files[0].RelPath(), "not-a-template")
} }
func TestRendererFileTreeRendering(t *testing.T) { func TestRendererFileTreeRendering(t *testing.T) {
@ -635,7 +603,7 @@ func TestRendererFileTreeRendering(t *testing.T) {
r, err := newRenderer(ctx, map[string]any{ r, err := newRenderer(ctx, map[string]any{
"dir_name": "my_directory", "dir_name": "my_directory",
"file_name": "my_file", "file_name": "my_file",
}, helpers, "./testdata/file-tree-rendering/template", "./testdata/file-tree-rendering/library", tmpDir) }, helpers, os.DirFS("."), "./testdata/file-tree-rendering/template", "./testdata/file-tree-rendering/library")
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -643,9 +611,11 @@ func TestRendererFileTreeRendering(t *testing.T) {
// Assert in memory representation is created. // Assert in memory representation is created.
assert.Len(t, r.files, 1) assert.Len(t, r.files, 1)
assert.Equal(t, r.files[0].DstPath().absPath(), filepath.Join(tmpDir, "my_directory", "my_file")) assert.Equal(t, r.files[0].RelPath(), "my_directory/my_file")
err = r.persistToDisk() out, err := filer.NewLocalClient(tmpDir)
require.NoError(t, err)
err = r.persistToDisk(ctx, out)
require.NoError(t, err) require.NoError(t, err)
// Assert files and directories are correctly materialized. // Assert files and directories are correctly materialized.
@ -667,8 +637,7 @@ func TestRendererSubTemplateInPath(t *testing.T) {
// https://learn.microsoft.com/en-us/windows/win32/fileio/naming-a-file. // https://learn.microsoft.com/en-us/windows/win32/fileio/naming-a-file.
testutil.Touch(t, filepath.Join(templateDir, "template/{{template `dir_name`}}/{{template `file_name`}}")) testutil.Touch(t, filepath.Join(templateDir, "template/{{template `dir_name`}}/{{template `file_name`}}"))
tmpDir := t.TempDir() r, err := newRenderer(ctx, nil, nil, os.DirFS(templateDir), "template", "library")
r, err := newRenderer(ctx, nil, nil, filepath.Join(templateDir, "template"), filepath.Join(templateDir, "library"), tmpDir)
require.NoError(t, err) require.NoError(t, err)
err = r.walk() err = r.walk()
@ -676,7 +645,6 @@ func TestRendererSubTemplateInPath(t *testing.T) {
if assert.Len(t, r.files, 2) { if assert.Len(t, r.files, 2) {
f := r.files[1] f := r.files[1]
assert.Equal(t, filepath.Join(tmpDir, "my_directory", "my_file"), f.DstPath().absPath()) assert.Equal(t, "my_directory/my_file", f.RelPath())
assert.Equal(t, "my_directory/my_file", f.DstPath().relPath)
} }
} }