diff --git a/acceptance/acceptance_test.go b/acceptance/acceptance_test.go index 302a1b50d..5f53f8d41 100644 --- a/acceptance/acceptance_test.go +++ b/acceptance/acceptance_test.go @@ -6,6 +6,7 @@ import ( "errors" "flag" "fmt" + "github.com/google/uuid" "io" "net/http" "os" @@ -20,8 +21,6 @@ import ( "time" "unicode/utf8" - "github.com/google/uuid" - "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/testdiff" @@ -41,7 +40,7 @@ var ( // In order to debug CLI running under acceptance test, set this to full subtest name, e.g. "bundle/variables/empty" // Then install your breakpoints and click "debug test" near TestAccept in VSCODE. // example: var SingleTest = "bundle/variables/empty" -var SingleTest = "" +var SingleTest = "bundle/deployment/bind/schema" // If enabled, instead of compiling and running CLI externally, we'll start in-process server that accepts and runs // CLI commands. The $CLI in test scripts is a helper that just forwards command-line arguments to this server (see bin/callserver.py). @@ -58,7 +57,7 @@ const ( EntryPointScript = "script" CleanupScript = "script.cleanup" PrepareScript = "script.prepare" - MaxFileSize = 100_000 + MaxFileSize = 300_000 // Filename to save replacements to (used by diff.py) ReplsFile = "repls.json" ) @@ -77,6 +76,10 @@ func TestAccept(t *testing.T) { testAccept(t, InprocessMode, SingleTest) } +func TestAcceptLocal(t *testing.T) { + testAccept(t, InprocessMode, SingleTest) +} + func TestInprocessMode(t *testing.T) { if InprocessMode { t.Skip("Already tested by TestAccept") @@ -223,10 +226,9 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont if !isTruePtr(config.Cloud) && cloudEnv != "" { t.Skipf("Disabled via Cloud setting in %s (CLOUD_ENV=%s)", configPath, cloudEnv) - } else { - if isTruePtr(config.RequiresUnityCatalog) && os.Getenv("TEST_METASTORE_ID") == "" { - t.Skipf("Skipping on non-UC workspaces") - } + } + if cloudEnv != "" && isTruePtr(config.RequiresUnityCatalog) && os.Getenv("TEST_METASTORE_ID") == "" { + t.Skipf("Skipping on non-UC workspaces") } var tmpDir string @@ -334,6 +336,22 @@ func runTest(t *testing.T, dir, coverDir string, repls testdiff.ReplacementsCont user = *pUser } + if cloudEnv != "" && isTruePtr(config.RecordRequests) { + // Start a new recording proxy for this test + logPath := filepath.Join(tmpDir, "out.request-recordings.txt") + targetHost := os.Getenv("DATABRICKS_HOST") + proxyServer, err := testserver.NewProxyRecorder(targetHost, logPath) + t.Logf("Starting a recording proxy server; proxies to " + targetHost + "; recording requests to " + logPath) + if err != nil { + t.Fatalf("Failed to create proxy server: %v", err) + } + t.Cleanup(proxyServer.Close) + + t.Logf("Setting DATABRICKS_DEFAULT_HOST to " + proxyServer.URL()) + t.Setenv("DATABRICKS_DEFAULT_HOST", proxyServer.URL()) + cmd.Env = append(cmd.Env, "DATABRICKS_HOST="+proxyServer.URL()) + } + testdiff.PrepareReplacementsUser(t, &repls, user) testdiff.PrepareReplacementsWorkspaceClient(t, &repls, workspaceClient) diff --git a/acceptance/bundle/deployment/bind/schema/test.toml b/acceptance/bundle/deployment/bind/schema/test.toml index 46518d61e..258fb83b0 100644 --- a/acceptance/bundle/deployment/bind/schema/test.toml +++ b/acceptance/bundle/deployment/bind/schema/test.toml @@ -1,3 +1,32 @@ -Local = false +Local = true Cloud = true RequiresUnityCatalog = true +RecordRequests = true + +[[Server]] +Pattern = "POST /api/2.1/unity-catalog/schemas" +Response.Body = ''' +{ + "name":"test-schema-6260d50f-e8ff-4905-8f28-812345678903", + "catalog_name":"main", + "enable_auto_maintenance":"INHERIT", + "enable_predictive_optimization":"INHERIT", + "full_name":"main.test-schema-6260d50f-e8ff-4905-8f28-812345678903", + "created_at":1741363224990, + "created_by":"[USERNAME]", + "updated_at":1741363224990, + "updated_by":"[USERNAME]", + "catalog_type":"MANAGED_CATALOG", + "schema_id":"6260d50f-e8ff-4905-8f28-812345678903", + "securable_type":"SCHEMA", + "securable_kind":"SCHEMA_STANDARD", + "browse_only":false, + "metastore_version":-1 +} +''' + +[[Server]] +Pattern = "GET /api/2.1/unity-catalog/schemas/{schema_fullname}" + +[[Server]] +Pattern = "POST /api/2.1/unity-catalog/schemas" \ No newline at end of file diff --git a/libs/testserver/proxy_recorder.go b/libs/testserver/proxy_recorder.go new file mode 100644 index 000000000..47708098b --- /dev/null +++ b/libs/testserver/proxy_recorder.go @@ -0,0 +1,287 @@ +package testserver + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "time" + "unicode/utf8" +) + +// ProxyRecorder represents a proxy server that records HTTP traffic +type ProxyRecorder struct { + // Remote server to proxy to + RemoteURL string + // File to record traffic to + RecordFile *os.File + // Lock for concurrent writes to the file + mu sync.Mutex + // The test server + Server *httptest.Server +} + +// NewProxyRecorder creates a new proxy recorder +func NewProxyRecorder(remoteURL, recordFilePath string) (*ProxyRecorder, error) { + // Open or create the record file + file, err := os.OpenFile(recordFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open record file: %w", err) + } + + // Create the proxy recorder + pr := &ProxyRecorder{ + RemoteURL: remoteURL, + RecordFile: file, + } + + // Create the test server + pr.Server = httptest.NewServer(http.HandlerFunc(pr.proxyHandler)) + + return pr, nil +} + +// Close stops the server and closes the record file +func (pr *ProxyRecorder) Close() { + pr.Server.Close() + err := pr.RecordFile.Close() + if err != nil { + log.Fatalf("Error closing record file") + } +} + +// URL returns the URL of the test server +func (pr *ProxyRecorder) URL() string { + return pr.Server.URL +} + +// proxyHandler handles HTTP requests, forwards them to the remote server, +// and records both the request and response +func (pr *ProxyRecorder) proxyHandler(w http.ResponseWriter, r *http.Request) { + // Record timestamp + timestamp := time.Now().Format(time.RFC3339) + + // Dump the request to be recorded + var requestDump strings.Builder + requestDump.WriteString(fmt.Sprintf("=== REQUEST %s ===\n", timestamp)) + requestDump.WriteString(fmt.Sprintf("%s %s %s\n", r.Method, r.URL.Path, r.Proto)) + + // Record headers + for name, values := range r.Header { + for _, value := range values { + requestDump.WriteString(fmt.Sprintf("%s: %s\n", name, value)) + } + } + + // Read and record the request body, if any + var requestBody []byte + if r.Body != nil { + var err error + requestBody, err = io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + // Close the original body and replace it with a new reader + r.Body.Close() + r.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + if len(requestBody) > 0 { + requestDump.WriteString("\n") + + // Check content type to handle binary data appropriately + contentType := r.Header.Get("Content-Type") + if isBinaryContent(contentType) { + // For binary content, log length and format instead of raw content + requestDump.WriteString(fmt.Sprintf("[Binary data, %d bytes, Content-Type: %s]\n", + len(requestBody), contentType)) + } else { + // For text content, convert to UTF-8 if needed + bodyStr := safeStringConversion(requestBody) + requestDump.WriteString(bodyStr) + } + } + } + requestDump.WriteString("\n\n") + + // Create a new request to the remote server + remoteURL := pr.RemoteURL + r.URL.Path + if r.URL.RawQuery != "" { + remoteURL += "?" + r.URL.RawQuery + } + + proxyReq, err := http.NewRequest(r.Method, remoteURL, bytes.NewBuffer(requestBody)) + if err != nil { + http.Error(w, "Failed to create proxy request", http.StatusInternalServerError) + return + } + + // Copy headers + for name, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(name, value) + } + } + + // Send the request to the remote server + client := &http.Client{} + resp, err := client.Do(proxyReq) + if err != nil { + http.Error(w, "Failed to proxy request", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Record the response + var responseDump strings.Builder + responseDump.WriteString(fmt.Sprintf("=== RESPONSE %s ===\n", timestamp)) + responseDump.WriteString(fmt.Sprintf("%s %d %s\n", resp.Proto, resp.StatusCode, resp.Status)) + + // Record response headers + for name, values := range resp.Header { + for _, value := range values { + responseDump.WriteString(fmt.Sprintf("%s: %s\n", name, value)) + } + } + + // Read and record the response body + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + return + } + + var bodyForLogging []byte + isGzipped := false + + // Check if response is gzipped + for _, encoding := range resp.Header.Values("Content-Encoding") { + if strings.Contains(strings.ToLower(encoding), "gzip") { + isGzipped = true + break + } + } + + if isGzipped { + // Decompress for logging + gzipReader, err := gzip.NewReader(bytes.NewReader(responseBody)) + if err != nil { + // If decompression fails, just log that it's compressed + responseDump.WriteString(fmt.Sprintf("\n[Gzipped content, %d bytes]\n", len(responseBody))) + } else { + defer gzipReader.Close() + + // Read the decompressed content + decompressed, err := io.ReadAll(gzipReader) + if err != nil { + responseDump.WriteString(fmt.Sprintf("\n[Error decompressing gzipped content: %v]\n", err)) + } else { + bodyForLogging = decompressed + responseDump.WriteString("\n[Decompressed from gzip]\n") + } + } + } else { + bodyForLogging = responseBody + } + + if len(bodyForLogging) > 0 { + // Check content type to handle binary data appropriately + contentType := resp.Header.Get("Content-Type") + if isBinaryContent(contentType) { + // For binary content, log length and format instead of raw content + responseDump.WriteString(fmt.Sprintf("[Binary data, %d bytes, Content-Type: %s]\n", + len(responseBody), contentType)) + } else { + // For text content, convert to UTF-8 if needed + bodyStr := safeStringConversion(bodyForLogging) + responseDump.WriteString(bodyStr) + } + } + responseDump.WriteString("\n\n") + + // Write the request and response to the record file + pr.mu.Lock() + pr.RecordFile.WriteString(requestDump.String()) + pr.RecordFile.WriteString(responseDump.String()) + pr.mu.Unlock() + + // Copy the response headers to the original response writer + for name, values := range resp.Header { + for _, value := range values { + w.Header().Add(name, value) + } + } + + // Set the status code + w.WriteHeader(resp.StatusCode) + + // Write the response body to the original response writer + w.Write(responseBody) +} + +// ensureValidUTF8 converts a string to valid UTF-8, replacing invalid byte sequences +func ensureValidUTF8(s string) string { + if utf8.ValidString(s) { + return s + } + + // Replace invalid UTF-8 sequences with the Unicode replacement character (U+FFFD) + v := make([]rune, 0, len(s)) + for i, r := range s { + if r == utf8.RuneError { + _, size := utf8.DecodeRuneInString(s[i:]) + if size == 1 { + // Invalid UTF-8 sequence, replace with Unicode replacement character + v = append(v, '\uFFFD') + continue + } + } + v = append(v, r) + } + return string(v) +} + +// isBinaryContent checks if the content type represents binary data +func isBinaryContent(contentType string) bool { + if contentType == "" { + return false + } + + // List of common binary content types + binaryTypes := []string{ + "image/", "audio/", "video/", "application/octet-stream", + "application/pdf", "application/zip", "application/gzip", + "application/x-tar", "application/x-rar-compressed", + "application/x-7z-compressed", "application/x-msdownload", + "application/vnd.ms-", "application/x-ms", + } + + for _, prefix := range binaryTypes { + if strings.HasPrefix(contentType, prefix) { + return true + } + } + + return false +} + +// safeStringConversion attempts to convert bytes to a valid UTF-8 string +func safeStringConversion(data []byte) string { + // First try direct conversion + s := string(data) + + // Check if it's valid UTF-8 + if utf8.ValidString(s) { + return s + } + + // Try to detect encoding and convert to UTF-8 + // For simplicity, we'll just replace invalid characters + return ensureValidUTF8(s) +}