Merge remote-tracking branch 'origin' into async-logger-clean

This commit is contained in:
Shreyas Goenka 2025-02-18 16:04:50 +01:00
commit 8a0b359f8f
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
35 changed files with 525 additions and 567 deletions

View File

@ -1,5 +1,46 @@
# Version changelog # Version changelog
## [Release] Release v0.241.2
This is a bugfix release to address an issue where jobs with tasks with a
libraries section with PyPI packages could not be deployed.
Bundles:
* Revert changes related to basename check for local libraries ([#2345](https://github.com/databricks/cli/pull/2345)).
## [Release] Release v0.241.1
Bundles:
* Fix for regression deploying resources with PyPi and Maven library types ([#2341](https://github.com/databricks/cli/pull/2341)).
## [Release] Release v0.241.0
Bundles:
* Added support to generate Git based jobs ([#2304](https://github.com/databricks/cli/pull/2304)).
* Added support for run_as in pipelines ([#2287](https://github.com/databricks/cli/pull/2287)).
* Raise an error when there are multiple local libraries with the same basename used ([#2297](https://github.com/databricks/cli/pull/2297)).
* Fix env variable for AzureCli local config ([#2248](https://github.com/databricks/cli/pull/2248)).
* Accept JSON files in includes section ([#2265](https://github.com/databricks/cli/pull/2265)).
* Always print warnings and errors; clean up format ([#2213](https://github.com/databricks/cli/pull/2213))
API Changes:
* Added `databricks account budget-policy` command group.
* Added `databricks lakeview-embedded` command group.
* Added `databricks query-execution` command group.
* Added `databricks account enable-ip-access-lists` command group.
* Added `databricks redash-config` command group.
OpenAPI commit c72c58f97b950fcb924a90ef164bcb10cfcd5ece (2025-02-03)
Dependency updates:
* Upgrade to TF provider 1.65.1 ([#2328](https://github.com/databricks/cli/pull/2328)).
* Bump github.com/hashicorp/terraform-exec from 0.21.0 to 0.22.0 ([#2237](https://github.com/databricks/cli/pull/2237)).
* Bump github.com/spf13/pflag from 1.0.5 to 1.0.6 ([#2281](https://github.com/databricks/cli/pull/2281)).
* Bump github.com/databricks/databricks-sdk-go from 0.56.1 to 0.57.0 ([#2321](https://github.com/databricks/cli/pull/2321)).
* Bump golang.org/x/oauth2 from 0.25.0 to 0.26.0 ([#2322](https://github.com/databricks/cli/pull/2322)).
* Bump golang.org/x/term from 0.28.0 to 0.29.0 ([#2325](https://github.com/databricks/cli/pull/2325)).
* Bump golang.org/x/text from 0.21.0 to 0.22.0 ([#2323](https://github.com/databricks/cli/pull/2323)).
* Bump golang.org/x/mod from 0.22.0 to 0.23.0 ([#2324](https://github.com/databricks/cli/pull/2324)).
## [Release] Release v0.240.0 ## [Release] Release v0.240.0
Bundles: Bundles:

View File

@ -7,7 +7,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@ -58,6 +57,8 @@ const (
CleanupScript = "script.cleanup" CleanupScript = "script.cleanup"
PrepareScript = "script.prepare" PrepareScript = "script.prepare"
MaxFileSize = 100_000 MaxFileSize = 100_000
// Filename to save replacements to (used by diff.py)
ReplsFile = "repls.json"
) )
var Scripts = map[string]bool{ var Scripts = map[string]bool{
@ -66,6 +67,10 @@ var Scripts = map[string]bool{
PrepareScript: true, PrepareScript: true,
} }
var Ignored = map[string]bool{
ReplsFile: true,
}
func TestAccept(t *testing.T) { func TestAccept(t *testing.T) {
testAccept(t, InprocessMode, SingleTest) testAccept(t, InprocessMode, SingleTest)
} }
@ -153,6 +158,8 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int {
testdiff.PrepareReplacementSdkVersion(t, &repls) testdiff.PrepareReplacementSdkVersion(t, &repls)
testdiff.PrepareReplacementsGoVersion(t, &repls) testdiff.PrepareReplacementsGoVersion(t, &repls)
repls.SetPath(cwd, "[TESTROOT]")
repls.Repls = append(repls.Repls, testdiff.Replacement{Old: regexp.MustCompile("dbapi[0-9a-f]+"), New: "[DATABRICKS_TOKEN]"}) repls.Repls = append(repls.Repls, testdiff.Replacement{Old: regexp.MustCompile("dbapi[0-9a-f]+"), New: "[DATABRICKS_TOKEN]"})
testDirs := getTests(t) testDirs := getTests(t)
@ -259,16 +266,16 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
server.RecordRequests = config.RecordRequests server.RecordRequests = config.RecordRequests
server.IncludeRequestHeaders = config.IncludeRequestHeaders server.IncludeRequestHeaders = config.IncludeRequestHeaders
// We want later stubs takes precedence, because then leaf configs take precedence over parent directory configs
// In gorilla/mux earlier handlers take precedence, so we need to reverse the order
slices.Reverse(config.Server)
for _, stub := range config.Server { for _, stub := range config.Server {
require.NotEmpty(t, stub.Pattern) require.NotEmpty(t, stub.Pattern)
items := strings.Split(stub.Pattern, " ") items := strings.Split(stub.Pattern, " ")
require.Len(t, items, 2) require.Len(t, items, 2)
server.Handle(items[0], items[1], func(fakeWorkspace *testserver.FakeWorkspace, req *http.Request) (any, int) { server.Handle(items[0], items[1], func(req testserver.Request) any {
statusCode := http.StatusOK return stub.Response
if stub.Response.StatusCode != 0 {
statusCode = stub.Response.StatusCode
}
return stub.Response.Body, statusCode
}) })
} }
@ -311,6 +318,11 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
// User replacements come last: // User replacements come last:
repls.Repls = append(repls.Repls, config.Repls...) repls.Repls = append(repls.Repls, config.Repls...)
// Save replacements to temp test directory so that it can be read by diff.py
replsJson, err := json.MarshalIndent(repls.Repls, "", " ")
require.NoError(t, err)
testutil.WriteFile(t, filepath.Join(tmpDir, ReplsFile), string(replsJson))
if coverDir != "" { if coverDir != "" {
// Creating individual coverage directory for each test, because writing to the same one // Creating individual coverage directory for each test, because writing to the same one
// results in sporadic failures like this one (only if tests are running in parallel): // results in sporadic failures like this one (only if tests are running in parallel):
@ -321,6 +333,10 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
cmd.Env = append(cmd.Env, "GOCOVERDIR="+coverDir) cmd.Env = append(cmd.Env, "GOCOVERDIR="+coverDir)
} }
absDir, err := filepath.Abs(dir)
require.NoError(t, err)
cmd.Env = append(cmd.Env, "TESTDIR="+absDir)
// Write combined output to a file // Write combined output to a file
out, err := os.Create(filepath.Join(tmpDir, "output.txt")) out, err := os.Create(filepath.Join(tmpDir, "output.txt"))
require.NoError(t, err) require.NoError(t, err)
@ -369,6 +385,9 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont
if _, ok := outputs[relPath]; ok { if _, ok := outputs[relPath]; ok {
continue continue
} }
if _, ok := Ignored[relPath]; ok {
continue
}
unexpected = append(unexpected, relPath) unexpected = append(unexpected, relPath)
if strings.HasPrefix(relPath, "out") { if strings.HasPrefix(relPath, "out") {
// We have a new file starting with "out" // We have a new file starting with "out"
@ -403,8 +422,7 @@ func doComparison(t *testing.T, repls testdiff.ReplacementsContext, dirRef, dirN
// The test did not produce an expected output file. // The test did not produce an expected output file.
if okRef && !okNew { if okRef && !okNew {
t.Errorf("Missing output file: %s\npathRef: %s\npathNew: %s", relPath, pathRef, pathNew) t.Errorf("Missing output file: %s", relPath)
testdiff.AssertEqualTexts(t, pathRef, pathNew, valueRef, valueNew)
if testdiff.OverwriteMode { if testdiff.OverwriteMode {
t.Logf("Removing output file: %s", relPath) t.Logf("Removing output file: %s", relPath)
require.NoError(t, os.Remove(pathRef)) require.NoError(t, os.Remove(pathRef))

56
acceptance/bin/diff.py Executable file
View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""This script implements "diff -r -U2 dir1 dir2" but applies replacements first"""
import sys
import difflib
import json
import re
from pathlib import Path
def replaceAll(patterns, s):
for comp, new in patterns:
s = comp.sub(new, s)
return s
def main():
d1, d2 = sys.argv[1:]
d1, d2 = Path(d1), Path(d2)
with open("repls.json") as f:
repls = json.load(f)
patterns = []
for r in repls:
try:
c = re.compile(r["Old"])
patterns.append((c, r["New"]))
except re.error as e:
print(f"Regex error for pattern {r}: {e}", file=sys.stderr)
files1 = [str(p.relative_to(d1)) for p in d1.rglob("*") if p.is_file()]
files2 = [str(p.relative_to(d2)) for p in d2.rglob("*") if p.is_file()]
set1 = set(files1)
set2 = set(files2)
for f in sorted(set1 | set2):
p1 = d1 / f
p2 = d2 / f
if f not in set2:
print(f"Only in {d1}: {f}")
elif f not in set1:
print(f"Only in {d2}: {f}")
else:
a = [replaceAll(patterns, x) for x in p1.read_text().splitlines(True)]
b = [replaceAll(patterns, x) for x in p2.read_text().splitlines(True)]
if a != b:
p1_str = p1.as_posix()
p2_str = p2.as_posix()
for line in difflib.unified_diff(a, b, p1_str, p2_str, "", "", 2):
print(line, end="")
if __name__ == "__main__":
main()

View File

@ -1,50 +0,0 @@
bundle:
name: same_name_libraries
variables:
cluster:
default:
spark_version: 15.4.x-scala2.12
node_type_id: i3.xlarge
data_security_mode: SINGLE_USER
num_workers: 0
spark_conf:
spark.master: "local[*, 4]"
spark.databricks.cluster.profile: singleNode
custom_tags:
ResourceClass: SingleNode
artifacts:
whl1:
type: whl
path: ./whl1
whl2:
type: whl
path: ./whl2
resources:
jobs:
test:
name: "test"
tasks:
- task_key: task1
new_cluster: ${var.cluster}
python_wheel_task:
entry_point: main
package_name: my_default_python
libraries:
- whl: ./whl1/dist/*.whl
- task_key: task2
new_cluster: ${var.cluster}
python_wheel_task:
entry_point: main
package_name: my_default_python
libraries:
- whl: ./whl2/dist/*.whl
- task_key: task3
new_cluster: ${var.cluster}
python_wheel_task:
entry_point: main
package_name: my_default_python
libraries:
- whl: ./whl1/dist/*.whl

View File

@ -1,14 +0,0 @@
>>> errcode [CLI] bundle deploy
Building whl1...
Building whl2...
Error: Duplicate local library name my_default_python-0.0.1-py3-none-any.whl
at resources.jobs.test.tasks[0].libraries[0].whl
resources.jobs.test.tasks[1].libraries[0].whl
in databricks.yml:36:15
databricks.yml:43:15
Local library names must be unique
Exit code: 1

View File

@ -1,2 +0,0 @@
trace errcode $CLI bundle deploy
rm -rf whl1 whl2

View File

@ -1,36 +0,0 @@
"""
setup.py configuration script describing how to build and package this project.
This file is primarily used by the setuptools library and typically should not
be executed directly. See README.md for how to deploy, test, and run
the my_default_python project.
"""
from setuptools import setup, find_packages
import sys
sys.path.append("./src")
import my_default_python
setup(
name="my_default_python",
version=my_default_python.__version__,
url="https://databricks.com",
author="[USERNAME]",
description="wheel file based on my_default_python/src",
packages=find_packages(where="./src"),
package_dir={"": "src"},
entry_points={
"packages": [
"main=my_default_python.main:main",
],
},
install_requires=[
# Dependencies in case the output wheel file is used as a library dependency.
# For defining dependencies, when this package is used in Databricks, see:
# https://docs.databricks.com/dev-tools/bundles/library-dependencies.html
"setuptools"
],
)

View File

@ -1,36 +0,0 @@
"""
setup.py configuration script describing how to build and package this project.
This file is primarily used by the setuptools library and typically should not
be executed directly. See README.md for how to deploy, test, and run
the my_default_python project.
"""
from setuptools import setup, find_packages
import sys
sys.path.append("./src")
import my_default_python
setup(
name="my_default_python",
version=my_default_python.__version__,
url="https://databricks.com",
author="[USERNAME]",
description="wheel file based on my_default_python/src",
packages=find_packages(where="./src"),
package_dir={"": "src"},
entry_points={
"packages": [
"main=my_default_python.main:main",
],
},
install_requires=[
# Dependencies in case the output wheel file is used as a library dependency.
# For defining dependencies, when this package is used in Databricks, see:
# https://docs.databricks.com/dev-tools/bundles/library-dependencies.html
"setuptools"
],
)

View File

@ -9,7 +9,7 @@
10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:folder_permissions 10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:folder_permissions
10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:validate_sync_patterns 10:07:59 Debug: ApplyReadOnly pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:validate_sync_patterns
10:07:59 Debug: Path /Workspace/Users/[USERNAME]/.bundle/debug/default/files has type directory (ID: 0) pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync 10:07:59 Debug: Path /Workspace/Users/[USERNAME]/.bundle/debug/default/files has type directory (ID: 0) pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync
10:07:59 Debug: non-retriable error: pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true 10:07:59 Debug: non-retriable error: Workspace path not found pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< {} pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < HTTP/0.0 000 OK pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< {} pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < } pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true
< } pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true < } pid=12345 mutator=validate mutator (read-only)=parallel mutator (read-only)=validate:files_to_sync sdk=true

View File

@ -79,11 +79,12 @@
10:07:59 Debug: Apply pid=12345 mutator=validate 10:07:59 Debug: Apply pid=12345 mutator=validate
10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files 10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files
< HTTP/1.1 404 Not Found < HTTP/1.1 404 Not Found
< {
< "message": "Workspace path not found"
10:07:59 Debug: POST /api/2.0/workspace/mkdirs 10:07:59 Debug: POST /api/2.0/workspace/mkdirs
> { > {
> "path": "/Workspace/Users/[USERNAME]/.bundle/debug/default/files" > "path": "/Workspace/Users/[USERNAME]/.bundle/debug/default/files"
> } > }
< HTTP/1.1 200 OK
10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files 10:07:59 Debug: GET /api/2.0/workspace/get-status?path=/Workspace/Users/[USERNAME]/.bundle/debug/default/files
< HTTP/1.1 200 OK < HTTP/1.1 200 OK
< { < {

View File

@ -37,10 +37,12 @@ The 'my_default_python' project was generated by using the default-python templa
``` ```
$ databricks bundle run $ databricks bundle run
``` ```
6. Optionally, install the Databricks extension for Visual Studio code for local development from
6. Optionally, install developer tools such as the Databricks extension for Visual Studio Code from https://docs.databricks.com/dev-tools/vscode-ext.html. It can configure your
https://docs.databricks.com/dev-tools/vscode-ext.html. Or read the "getting started" documentation for virtual environment and setup Databricks Connect for running unit tests locally.
**Databricks Connect** for instructions on running the included Python code from a different IDE. When not using these tools, consult your development environment's documentation
and/or the documentation for Databricks Connect for manually setting up your environment
(https://docs.databricks.com/en/dev-tools/databricks-connect/python/index.html).
7. For documentation on the Databricks asset bundles format used 7. For documentation on the Databricks asset bundles format used
for this project, and for CI/CD configuration, see for this project, and for CI/CD configuration, see

View File

@ -1,8 +1,8 @@
package acceptance_test package acceptance_test
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -14,7 +14,7 @@ import (
func StartCmdServer(t *testing.T) *testserver.Server { func StartCmdServer(t *testing.T) *testserver.Server {
server := testserver.New(t) server := testserver.New(t)
server.Handle("GET", "/", func(_ *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/", func(r testserver.Request) any {
q := r.URL.Query() q := r.URL.Query()
args := strings.Split(q.Get("args"), " ") args := strings.Split(q.Get("args"), " ")
@ -27,7 +27,7 @@ func StartCmdServer(t *testing.T) *testserver.Server {
defer Chdir(t, q.Get("cwd"))() defer Chdir(t, q.Get("cwd"))()
c := testcli.NewRunner(t, r.Context(), args...) c := testcli.NewRunner(t, context.Background(), args...)
c.Verbose = false c.Verbose = false
stdout, stderr, err := c.Run() stdout, stderr, err := c.Run()
result := map[string]any{ result := map[string]any{
@ -39,7 +39,7 @@ func StartCmdServer(t *testing.T) *testserver.Server {
exitcode = 1 exitcode = 1
} }
result["exitcode"] = exitcode result["exitcode"] = exitcode
return result, http.StatusOK return result
}) })
return server return server
} }

View File

@ -10,6 +10,7 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/databricks/cli/libs/testdiff" "github.com/databricks/cli/libs/testdiff"
"github.com/databricks/cli/libs/testserver"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -56,10 +57,7 @@ type ServerStub struct {
Pattern string Pattern string
// The response body to return. // The response body to return.
Response struct { Response testserver.Response
Body string
StatusCode int
}
} }
// FindConfigs finds all the config relevant for this test, // FindConfigs finds all the config relevant for this test,

View File

@ -0,0 +1,7 @@
Hello!
{
"id": "[USERID]",
"userName": "[USERNAME]"
}
Footer

View File

@ -0,0 +1,7 @@
Hello!
{
"id": "[UUID]",
"userName": "[USERNAME]"
}
Footer

View File

@ -0,0 +1,13 @@
>>> diff.py out_dir_a out_dir_b
Only in out_dir_a: only_in_a
Only in out_dir_b: only_in_b
--- out_dir_a/output.txt
+++ out_dir_b/output.txt
@@ -1,5 +1,5 @@
Hello!
{
- "id": "[USERID]",
+ "id": "[UUID]",
"userName": "[USERNAME]"
}

View File

@ -0,0 +1,17 @@
mkdir out_dir_a
mkdir out_dir_b
touch out_dir_a/only_in_a
touch out_dir_b/only_in_b
echo Hello! >> out_dir_a/output.txt
echo Hello! >> out_dir_b/output.txt
curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me >> out_dir_a/output.txt
printf "\n\nFooter" >> out_dir_a/output.txt
printf '{\n "id": "7d639bad-ac6d-4e6f-abd7-9522a86b0239",\n "userName": "[USERNAME]"\n}\n\nFooter' >> out_dir_b/output.txt
# Unlike regular diff, diff.py will apply replacements first before doing the comparison
errcode trace diff.py out_dir_a out_dir_b
rm out_dir_a/only_in_a out_dir_b/only_in_b

View File

@ -6,3 +6,7 @@
"method": "GET", "method": "GET",
"path": "/custom/endpoint" "path": "/custom/endpoint"
} }
{
"method": "GET",
"path": "/api/2.0/workspace/get-status"
}

View File

@ -6,10 +6,16 @@
} }
>>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param >>> curl -sD - [DATABRICKS_URL]/custom/endpoint?query=param
HTTP/1.1 201 Created HTTP/1.1 201 Created
Content-Type: application/json X-Custom-Header: hello
Date: (redacted) Date: (redacted)
Content-Length: (redacted) Content-Length: (redacted)
Content-Type: text/plain; charset=utf-8
custom custom
--- ---
response response
>>> errcode [CLI] workspace get-status /a/b/c
Error: Workspace path not found
Exit code: 1

View File

@ -1,2 +1,4 @@
trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me trace curl -s $DATABRICKS_HOST/api/2.0/preview/scim/v2/Me
trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param trace curl -sD - $DATABRICKS_HOST/custom/endpoint?query=param
trace errcode $CLI workspace get-status /a/b/c

View File

@ -1,6 +1,10 @@
LocalOnly = true LocalOnly = true
RecordRequests = true RecordRequests = true
[[Server]]
Pattern = "GET /custom/endpoint"
Response.Body = '''should not see this response, latter response takes precedence'''
[[Server]] [[Server]]
Pattern = "GET /custom/endpoint" Pattern = "GET /custom/endpoint"
Response.Body = '''custom Response.Body = '''custom
@ -8,6 +12,8 @@ Response.Body = '''custom
response response
''' '''
Response.StatusCode = 201 Response.StatusCode = 201
[Server.Response.Headers]
"X-Custom-Header" = ["hello"]
[[Repls]] [[Repls]]
Old = 'Date: .*' Old = 'Date: .*'

View File

@ -0,0 +1 @@
LocalOnly = true

View File

@ -1,14 +1,12 @@
package acceptance_test package acceptance_test
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/catalog"
"github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/iam"
"github.com/gorilla/mux"
"github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/jobs"
@ -23,7 +21,7 @@ var testUser = iam.User{
} }
func AddHandlers(server *testserver.Server) { func AddHandlers(server *testserver.Server) {
server.Handle("GET", "/api/2.0/policies/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/policies/clusters/list", func(req testserver.Request) any {
return compute.ListPoliciesResponse{ return compute.ListPoliciesResponse{
Policies: []compute.Policy{ Policies: []compute.Policy{
{ {
@ -35,10 +33,10 @@ func AddHandlers(server *testserver.Server) {
Name: "some-test-cluster-policy", Name: "some-test-cluster-policy",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.0/instance-pools/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/instance-pools/list", func(req testserver.Request) any {
return compute.ListInstancePools{ return compute.ListInstancePools{
InstancePools: []compute.InstancePoolAndStats{ InstancePools: []compute.InstancePoolAndStats{
{ {
@ -46,10 +44,10 @@ func AddHandlers(server *testserver.Server) {
InstancePoolId: "1234", InstancePoolId: "1234",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.1/clusters/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/clusters/list", func(req testserver.Request) any {
return compute.ListClustersResponse{ return compute.ListClustersResponse{
Clusters: []compute.ClusterDetails{ Clusters: []compute.ClusterDetails{
{ {
@ -61,74 +59,60 @@ func AddHandlers(server *testserver.Server) {
ClusterId: "9876", ClusterId: "9876",
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/preview/scim/v2/Me", func(req testserver.Request) any {
return testUser, http.StatusOK return testserver.Response{
Headers: map[string][]string{"X-Databricks-Org-Id": {"900800700600"}},
Body: testUser,
}
}) })
server.Handle("GET", "/api/2.0/workspace/get-status", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/workspace/get-status", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
return req.Workspace.WorkspaceGetStatus(path)
return fakeWorkspace.WorkspaceGetStatus(path)
}) })
server.Handle("POST", "/api/2.0/workspace/mkdirs", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace/mkdirs", func(req testserver.Request) any {
request := workspace.Mkdirs{} var request workspace.Mkdirs
decoder := json.NewDecoder(r.Body) if err := json.Unmarshal(req.Body, &request); err != nil {
return testserver.Response{
err := decoder.Decode(&request) Body: fmt.Sprintf("internal error: %s", err),
if err != nil { StatusCode: http.StatusInternalServerError,
return internalError(err) }
} }
return fakeWorkspace.WorkspaceMkdirs(request) req.Workspace.WorkspaceMkdirs(request)
return ""
}) })
server.Handle("GET", "/api/2.0/workspace/export", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/workspace/export", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
return req.Workspace.WorkspaceExport(path)
return fakeWorkspace.WorkspaceExport(path)
}) })
server.Handle("POST", "/api/2.0/workspace/delete", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace/delete", func(req testserver.Request) any {
path := r.URL.Query().Get("path") path := req.URL.Query().Get("path")
recursiveStr := r.URL.Query().Get("recursive") recursive := req.URL.Query().Get("recursive") == "true"
var recursive bool req.Workspace.WorkspaceDelete(path, recursive)
return ""
if recursiveStr == "true" {
recursive = true
} else {
recursive = false
}
return fakeWorkspace.WorkspaceDelete(path, recursive)
}) })
server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(req testserver.Request) any {
vars := mux.Vars(r) path := req.Vars["path"]
path := vars["path"] req.Workspace.WorkspaceFilesImportFile(path, req.Body)
return ""
body := new(bytes.Buffer)
_, err := body.ReadFrom(r.Body)
if err != nil {
return internalError(err)
}
return fakeWorkspace.WorkspaceFilesImportFile(path, body.Bytes())
}) })
server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/unity-catalog/current-metastore-assignment", func(req testserver.Request) any {
return catalog.MetastoreAssignment{ return catalog.MetastoreAssignment{
DefaultCatalogName: "main", DefaultCatalogName: "main",
}, http.StatusOK }
}) })
server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.0/permissions/directories/{objectId}", func(req testserver.Request) any {
vars := mux.Vars(r) objectId := req.Vars["objectId"]
objectId := vars["objectId"]
return workspace.WorkspaceObjectPermissions{ return workspace.WorkspaceObjectPermissions{
ObjectId: objectId, ObjectId: objectId,
ObjectType: "DIRECTORY", ObjectType: "DIRECTORY",
@ -142,48 +126,43 @@ func AddHandlers(server *testserver.Server) {
}, },
}, },
}, },
}, http.StatusOK }
}) })
server.Handle("POST", "/api/2.1/jobs/create", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/api/2.1/jobs/create", func(req testserver.Request) any {
request := jobs.CreateJob{} var request jobs.CreateJob
decoder := json.NewDecoder(r.Body) if err := json.Unmarshal(req.Body, &request); err != nil {
return testserver.Response{
err := decoder.Decode(&request) Body: fmt.Sprintf("internal error: %s", err),
if err != nil { StatusCode: 500,
return internalError(err) }
} }
return fakeWorkspace.JobsCreate(request) return req.Workspace.JobsCreate(request)
}) })
server.Handle("GET", "/api/2.1/jobs/get", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/jobs/get", func(req testserver.Request) any {
jobId := r.URL.Query().Get("job_id") jobId := req.URL.Query().Get("job_id")
return req.Workspace.JobsGet(jobId)
return fakeWorkspace.JobsGet(jobId)
}) })
server.Handle("GET", "/api/2.1/jobs/list", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/api/2.1/jobs/list", func(req testserver.Request) any {
return fakeWorkspace.JobsList() return req.Workspace.JobsList()
}) })
server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("GET", "/oidc/.well-known/oauth-authorization-server", func(_ testserver.Request) any {
return map[string]string{ return map[string]string{
"authorization_endpoint": server.URL + "oidc/v1/authorize", "authorization_endpoint": server.URL + "oidc/v1/authorize",
"token_endpoint": server.URL + "/oidc/v1/token", "token_endpoint": server.URL + "/oidc/v1/token",
}, http.StatusOK }
}) })
server.Handle("POST", "/oidc/v1/token", func(fakeWorkspace *testserver.FakeWorkspace, r *http.Request) (any, int) { server.Handle("POST", "/oidc/v1/token", func(_ testserver.Request) any {
return map[string]string{ return map[string]string{
"access_token": "oauth-token", "access_token": "oauth-token",
"expires_in": "3600", "expires_in": "3600",
"scope": "all-apis", "scope": "all-apis",
"token_type": "Bearer", "token_type": "Bearer",
}, http.StatusOK }
}) })
} }
func internalError(err error) (any, int) {
return fmt.Errorf("internal error: %w", err), http.StatusInternalServerError
}

View File

@ -92,7 +92,7 @@ func expandLibraries(b *bundle.Bundle, p dyn.Path, v dyn.Value) (diag.Diagnostic
for _, match := range matches { for _, match := range matches {
output = append(output, dyn.NewValue(map[string]dyn.Value{ output = append(output, dyn.NewValue(map[string]dyn.Value{
libType: dyn.NewValue(match, lib.Locations()), libType: dyn.V(match),
}, lib.Locations())) }, lib.Locations()))
} }
} }

View File

@ -1,97 +0,0 @@
package libraries
import (
"context"
"path/filepath"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
)
type checkForSameNameLibraries struct{}
var patterns = []dyn.Pattern{
taskLibrariesPattern.Append(dyn.AnyIndex(), dyn.AnyKey()),
forEachTaskLibrariesPattern.Append(dyn.AnyIndex(), dyn.AnyKey()),
envDepsPattern.Append(dyn.AnyIndex()),
}
type libData struct {
fullPath string
locations []dyn.Location
paths []dyn.Path
}
func (c checkForSameNameLibraries) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
var diags diag.Diagnostics
libs := make(map[string]*libData)
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
var err error
for _, pattern := range patterns {
v, err = dyn.MapByPattern(v, pattern, func(p dyn.Path, lv dyn.Value) (dyn.Value, error) {
libPath := lv.MustString()
// If not local library, skip the check
if !IsLibraryLocal(libPath) {
return lv, nil
}
libFullPath := lv.MustString()
lib := filepath.Base(libFullPath)
// If the same basename was seen already but full path is different
// then it's a duplicate. Add the location to the location list.
lp, ok := libs[lib]
if !ok {
libs[lib] = &libData{
fullPath: libFullPath,
locations: []dyn.Location{lv.Location()},
paths: []dyn.Path{p},
}
} else if lp.fullPath != libFullPath {
lp.locations = append(lp.locations, lv.Location())
lp.paths = append(lp.paths, p)
}
return lv, nil
})
if err != nil {
return dyn.InvalidValue, err
}
}
if err != nil {
return dyn.InvalidValue, err
}
return v, nil
})
// Iterate over all the libraries and check if there are any duplicates.
// Duplicates will have more than one location.
// If there are duplicates, add a diagnostic.
for lib, lv := range libs {
if len(lv.locations) > 1 {
diags = append(diags, diag.Diagnostic{
Severity: diag.Error,
Summary: "Duplicate local library name " + lib,
Detail: "Local library names must be unique",
Locations: lv.locations,
Paths: lv.paths,
})
}
}
if err != nil {
diags = diags.Extend(diag.FromErr(err))
}
return diags
}
func (c checkForSameNameLibraries) Name() string {
return "CheckForSameNameLibraries"
}
func CheckForSameNameLibraries() bundle.Mutator {
return checkForSameNameLibraries{}
}

View File

@ -1,121 +0,0 @@
package libraries
import (
"context"
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/bundle/internal/bundletest"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/require"
)
func TestSameNameLibraries(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
Libraries: []compute.Library{
{
Whl: "full/path/test.whl",
},
},
},
{
Libraries: []compute.Library{
{
Whl: "other/path/test.whl",
},
},
},
},
},
},
},
},
},
}
bundletest.SetLocation(b, "resources.jobs.test.tasks[0]", []dyn.Location{
{File: "databricks.yml", Line: 10, Column: 1},
})
bundletest.SetLocation(b, "resources.jobs.test.tasks[1]", []dyn.Location{
{File: "databricks.yml", Line: 20, Column: 1},
})
diags := bundle.Apply(context.Background(), b, CheckForSameNameLibraries())
require.Len(t, diags, 1)
require.Equal(t, diag.Error, diags[0].Severity)
require.Equal(t, "Duplicate local library name test.whl", diags[0].Summary)
require.Equal(t, []dyn.Location{
{File: "databricks.yml", Line: 10, Column: 1},
{File: "databricks.yml", Line: 20, Column: 1},
}, diags[0].Locations)
paths := make([]string, 0)
for _, p := range diags[0].Paths {
paths = append(paths, p.String())
}
require.Equal(t, []string{
"resources.jobs.test.tasks[0].libraries[0].whl",
"resources.jobs.test.tasks[1].libraries[0].whl",
}, paths)
}
func TestSameNameLibrariesWithUniqueLibraries(t *testing.T) {
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"test": {
JobSettings: &jobs.JobSettings{
Tasks: []jobs.Task{
{
Libraries: []compute.Library{
{
Whl: "full/path/test-0.1.1.whl",
},
{
Whl: "cowsay",
},
},
},
{
Libraries: []compute.Library{
{
Whl: "other/path/test-0.1.0.whl",
},
{
Whl: "cowsay",
},
},
},
{
Libraries: []compute.Library{
{
Whl: "full/path/test-0.1.1.whl", // Use the same library as the first task
},
},
},
},
},
},
},
},
},
}
diags := bundle.Apply(context.Background(), b, CheckForSameNameLibraries())
require.Empty(t, diags)
}

View File

@ -155,11 +155,6 @@ func Deploy(outputHandler sync.OutputHandler) bundle.Mutator {
mutator.ValidateGitDetails(), mutator.ValidateGitDetails(),
artifacts.CleanUp(), artifacts.CleanUp(),
libraries.ExpandGlobReferences(), libraries.ExpandGlobReferences(),
// libraries.CheckForSameNameLibraries() needs to be run after we expand glob references so we
// know what are the actual library paths.
// libraries.ExpandGlobReferences() has to be run after the libraries are built and thus this
// mutator is part of the deploy step rather than validate.
libraries.CheckForSameNameLibraries(),
libraries.Upload(), libraries.Upload(),
trampoline.TransformWheelTask(), trampoline.TransformWheelTask(),
files.Upload(outputHandler), files.Upload(outputHandler),

View File

@ -38,10 +38,16 @@ The '{{.project_name}}' project was generated by using the default-python templa
$ databricks bundle run $ databricks bundle run
``` ```
{{- if (eq .include_python "no") }}
6. Optionally, install developer tools such as the Databricks extension for Visual Studio Code from 6. Optionally, install developer tools such as the Databricks extension for Visual Studio Code from
https://docs.databricks.com/dev-tools/vscode-ext.html. https://docs.databricks.com/dev-tools/vscode-ext.html.
{{- if (eq .include_python "yes") }} Or read the "getting started" documentation for {{- else }}
**Databricks Connect** for instructions on running the included Python code from a different IDE. 6. Optionally, install the Databricks extension for Visual Studio code for local development from
https://docs.databricks.com/dev-tools/vscode-ext.html. It can configure your
virtual environment and setup Databricks Connect for running unit tests locally.
When not using these tools, consult your development environment's documentation
and/or the documentation for Databricks Connect for manually setting up your environment
(https://docs.databricks.com/en/dev-tools/databricks-connect/python/index.html).
{{- end}} {{- end}}
7. For documentation on the Databricks asset bundles format used 7. For documentation on the Databricks asset bundles format used

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -33,40 +32,39 @@ func NewFakeWorkspace() *FakeWorkspace {
} }
} }
func (s *FakeWorkspace) WorkspaceGetStatus(path string) (workspace.ObjectInfo, int) { func (s *FakeWorkspace) WorkspaceGetStatus(path string) Response {
if s.directories[path] { if s.directories[path] {
return workspace.ObjectInfo{ return Response{
ObjectType: "DIRECTORY", Body: &workspace.ObjectInfo{
Path: path, ObjectType: "DIRECTORY",
}, http.StatusOK Path: path,
},
}
} else if _, ok := s.files[path]; ok { } else if _, ok := s.files[path]; ok {
return workspace.ObjectInfo{ return Response{
ObjectType: "FILE", Body: &workspace.ObjectInfo{
Path: path, ObjectType: "FILE",
Language: "SCALA", Path: path,
}, http.StatusOK Language: "SCALA",
},
}
} else { } else {
return workspace.ObjectInfo{}, http.StatusNotFound return Response{
StatusCode: 404,
Body: map[string]string{"message": "Workspace path not found"},
}
} }
} }
func (s *FakeWorkspace) WorkspaceMkdirs(request workspace.Mkdirs) (string, int) { func (s *FakeWorkspace) WorkspaceMkdirs(request workspace.Mkdirs) {
s.directories[request.Path] = true s.directories[request.Path] = true
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceExport(path string) ([]byte, int) { func (s *FakeWorkspace) WorkspaceExport(path string) []byte {
file := s.files[path] return s.files[path]
if file == nil {
return nil, http.StatusNotFound
}
return file, http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) (string, int) { func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) {
if !recursive { if !recursive {
s.files[path] = nil s.files[path] = nil
} else { } else {
@ -76,28 +74,26 @@ func (s *FakeWorkspace) WorkspaceDelete(path string, recursive bool) (string, in
} }
} }
} }
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) WorkspaceFilesImportFile(path string, body []byte) (any, int) { func (s *FakeWorkspace) WorkspaceFilesImportFile(path string, body []byte) {
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
path = "/" + path path = "/" + path
} }
s.files[path] = body s.files[path] = body
return "{}", http.StatusOK
} }
func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) (any, int) { func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) Response {
jobId := s.nextJobId jobId := s.nextJobId
s.nextJobId++ s.nextJobId++
jobSettings := jobs.JobSettings{} jobSettings := jobs.JobSettings{}
err := jsonConvert(request, &jobSettings) err := jsonConvert(request, &jobSettings)
if err != nil { if err != nil {
return internalError(err) return Response{
StatusCode: 400,
Body: fmt.Sprintf("Cannot convert request to jobSettings: %s", err),
}
} }
s.jobs[jobId] = jobs.Job{ s.jobs[jobId] = jobs.Job{
@ -105,32 +101,44 @@ func (s *FakeWorkspace) JobsCreate(request jobs.CreateJob) (any, int) {
Settings: &jobSettings, Settings: &jobSettings,
} }
return jobs.CreateResponse{JobId: jobId}, http.StatusOK return Response{
Body: jobs.CreateResponse{JobId: jobId},
}
} }
func (s *FakeWorkspace) JobsGet(jobId string) (any, int) { func (s *FakeWorkspace) JobsGet(jobId string) Response {
id := jobId id := jobId
jobIdInt, err := strconv.ParseInt(id, 10, 64) jobIdInt, err := strconv.ParseInt(id, 10, 64)
if err != nil { if err != nil {
return internalError(fmt.Errorf("failed to parse job id: %s", err)) return Response{
StatusCode: 400,
Body: fmt.Sprintf("Failed to parse job id: %s: %v", err, id),
}
} }
job, ok := s.jobs[jobIdInt] job, ok := s.jobs[jobIdInt]
if !ok { if !ok {
return jobs.Job{}, http.StatusNotFound return Response{
StatusCode: 404,
}
} }
return job, http.StatusOK return Response{
Body: job,
}
} }
func (s *FakeWorkspace) JobsList() (any, int) { func (s *FakeWorkspace) JobsList() Response {
list := make([]jobs.BaseJob, 0, len(s.jobs)) list := make([]jobs.BaseJob, 0, len(s.jobs))
for _, job := range s.jobs { for _, job := range s.jobs {
baseJob := jobs.BaseJob{} baseJob := jobs.BaseJob{}
err := jsonConvert(job, &baseJob) err := jsonConvert(job, &baseJob)
if err != nil { if err != nil {
return internalError(fmt.Errorf("failed to convert job to base job: %w", err)) return Response{
StatusCode: 400,
Body: fmt.Sprintf("failed to convert job to base job: %s", err),
}
} }
list = append(list, baseJob) list = append(list, baseJob)
@ -141,9 +149,11 @@ func (s *FakeWorkspace) JobsList() (any, int) {
return list[i].JobId < list[j].JobId return list[i].JobId < list[j].JobId
}) })
return jobs.ListJobsResponse{ return Response{
Jobs: list, Body: jobs.ListJobsResponse{
}, http.StatusOK Jobs: list,
},
}
} }
// jsonConvert saves input to a value pointed by output // jsonConvert saves input to a value pointed by output
@ -163,7 +173,3 @@ func jsonConvert(input, output any) error {
return nil return nil
} }
func internalError(err error) (string, int) {
return fmt.Sprintf("internal error: %s", err), http.StatusInternalServerError
}

View File

@ -5,14 +5,14 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"reflect"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/internal/testutil"
"github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/apierr"
) )
@ -29,10 +29,10 @@ type Server struct {
RecordRequests bool RecordRequests bool
IncludeRequestHeaders []string IncludeRequestHeaders []string
Requests []Request Requests []LoggedRequest
} }
type Request struct { type LoggedRequest struct {
Headers http.Header `json:"headers,omitempty"` Headers http.Header `json:"headers,omitempty"`
Method string `json:"method"` Method string `json:"method"`
Path string `json:"path"` Path string `json:"path"`
@ -40,6 +40,153 @@ type Request struct {
RawBody string `json:"raw_body,omitempty"` RawBody string `json:"raw_body,omitempty"`
} }
type Request struct {
Method string
URL *url.URL
Headers http.Header
Body []byte
Vars map[string]string
Workspace *FakeWorkspace
}
type Response struct {
StatusCode int
Headers http.Header
Body any
}
type encodedResponse struct {
StatusCode int
Headers http.Header
Body []byte
}
func NewRequest(t testutil.TestingT, r *http.Request, fakeWorkspace *FakeWorkspace) Request {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %s", err)
}
return Request{
Method: r.Method,
URL: r.URL,
Headers: r.Header,
Body: body,
Vars: mux.Vars(r),
Workspace: fakeWorkspace,
}
}
func normalizeResponse(t testutil.TestingT, resp any) encodedResponse {
result := normalizeResponseBody(t, resp)
if result.StatusCode == 0 {
result.StatusCode = 200
}
return result
}
func normalizeResponseBody(t testutil.TestingT, resp any) encodedResponse {
if isNil(resp) {
t.Errorf("Handler must not return nil")
return encodedResponse{StatusCode: 500}
}
respBytes, ok := resp.([]byte)
if ok {
return encodedResponse{
Body: respBytes,
Headers: getHeaders(respBytes),
}
}
respString, ok := resp.(string)
if ok {
return encodedResponse{
Body: []byte(respString),
Headers: getHeaders([]byte(respString)),
}
}
respStruct, ok := resp.(Response)
if ok {
if isNil(respStruct.Body) {
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: respStruct.Headers,
Body: []byte{},
}
}
bytesVal, isBytes := respStruct.Body.([]byte)
if isBytes {
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: respStruct.Headers,
Body: bytesVal,
}
}
stringVal, isString := respStruct.Body.(string)
if isString {
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: respStruct.Headers,
Body: []byte(stringVal),
}
}
respBytes, err := json.MarshalIndent(respStruct.Body, "", " ")
if err != nil {
t.Errorf("JSON encoding error: %s", err)
return encodedResponse{
StatusCode: 500,
Body: []byte("internal error"),
}
}
headers := respStruct.Headers
if headers == nil {
headers = getJsonHeaders()
}
return encodedResponse{
StatusCode: respStruct.StatusCode,
Headers: headers,
Body: respBytes,
}
}
respBytes, err := json.MarshalIndent(resp, "", " ")
if err != nil {
t.Errorf("JSON encoding error: %s", err)
return encodedResponse{
StatusCode: 500,
Body: []byte("internal error"),
}
}
return encodedResponse{
Body: respBytes,
Headers: getJsonHeaders(),
}
}
func getJsonHeaders() http.Header {
return map[string][]string{
"Content-Type": {"application/json"},
}
}
func getHeaders(value []byte) http.Header {
if json.Valid(value) {
return getJsonHeaders()
} else {
return map[string][]string{
"Content-Type": {"text/plain"},
}
}
}
func New(t testutil.TestingT) *Server { func New(t testutil.TestingT) *Server {
router := mux.NewRouter() router := mux.NewRouter()
server := httptest.NewServer(router) server := httptest.NewServer(router)
@ -96,7 +243,7 @@ Response.StatusCode = <response status-code here>
return s return s
} }
type HandlerFunc func(fakeWorkspace *FakeWorkspace, req *http.Request) (resp any, statusCode int) type HandlerFunc func(req Request) any
func (s *Server) Handle(method, path string, handler HandlerFunc) { func (s *Server) Handle(method, path string, handler HandlerFunc) {
s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
@ -117,56 +264,22 @@ func (s *Server) Handle(method, path string, handler HandlerFunc) {
fakeWorkspace = s.fakeWorkspaces[token] fakeWorkspace = s.fakeWorkspaces[token]
} }
resp, statusCode := handler(fakeWorkspace, r) request := NewRequest(s.t, r, fakeWorkspace)
if s.RecordRequests { if s.RecordRequests {
body, err := io.ReadAll(r.Body) s.Requests = append(s.Requests, getLoggedRequest(request, s.IncludeRequestHeaders))
assert.NoError(s.t, err)
headers := make(http.Header)
for k, v := range r.Header {
if !slices.Contains(s.IncludeRequestHeaders, k) {
continue
}
for _, vv := range v {
headers.Add(k, vv)
}
}
req := Request{
Headers: headers,
Method: r.Method,
Path: r.URL.Path,
}
if json.Valid(body) {
req.Body = json.RawMessage(body)
} else {
req.RawBody = string(body)
}
s.Requests = append(s.Requests, req)
} }
w.Header().Set("Content-Type", "application/json") respAny := handler(request)
w.WriteHeader(statusCode) resp := normalizeResponse(s.t, respAny)
var respBytes []byte for k, v := range resp.Headers {
var err error w.Header()[k] = v
if respString, ok := resp.(string); ok {
respBytes = []byte(respString)
} else if respBytes0, ok := resp.([]byte); ok {
respBytes = respBytes0
} else {
respBytes, err = json.MarshalIndent(resp, "", " ")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} }
if _, err := w.Write(respBytes); err != nil { w.WriteHeader(resp.StatusCode)
http.Error(w, err.Error(), http.StatusInternalServerError)
if _, err := w.Write(resp.Body); err != nil {
s.t.Errorf("Failed to write response: %s", err)
return return
} }
}).Methods(method) }).Methods(method)
@ -182,3 +295,43 @@ func getToken(r *http.Request) string {
return header[len(prefix):] return header[len(prefix):]
} }
func getLoggedRequest(req Request, includedHeaders []string) LoggedRequest {
result := LoggedRequest{
Method: req.Method,
Path: req.URL.Path,
Headers: filterHeaders(req.Headers, includedHeaders),
}
if json.Valid(req.Body) {
result.Body = json.RawMessage(req.Body)
} else {
result.RawBody = string(req.Body)
}
return result
}
func filterHeaders(h http.Header, includedHeaders []string) http.Header {
headers := make(http.Header)
for k, v := range h {
if !slices.Contains(includedHeaders, k) {
continue
}
headers[k] = v
}
return headers
}
func isNil(i any) bool {
if i == nil {
return true
}
v := reflect.ValueOf(i)
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice:
return v.IsNil()
default:
return false
}
}