// putContentLength returns the value to send as the Content-Length header of
// a multipart/form-data DBFS put request that uploads file to path.
//
// The PUT API for DBFS requires the content length to be set up front, but the
// request body is streamed, so the length is derived instead of measured: the
// form framing ("path" field, "overwrite" field, the header of the "contents"
// part with an empty file name, and the closing boundary) is rendered into a
// scratch buffer, and the size reported by file.Stat is added to it.
//
// path is the destination written into the "path" form field and
// overwriteField is the literal value of the "overwrite" form field. The
// returned length is in bytes. A non-nil error means the length could not be
// determined and the returned length is 0.
//
// NOTE: the scratch writer generates its own multipart boundary, which is not
// the boundary of the writer that produces the real request body. The result
// therefore only matches the real body as long as both boundaries have the
// same length.
// NOTE(review): the whole file size is counted, so this presumably expects
// file to be positioned at offset 0 when the body is streamed - confirm
// against the caller.
func putContentLength(path string, overwriteField string, file *os.File) (int64, error) {
	// (*os.File).Name panics on a nil receiver, so reject nil before the
	// error path below needs the name.
	if file == nil {
		return 0, fmt.Errorf("failed to stat file: %w", os.ErrInvalid)
	}

	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)
	if err := writer.WriteField("path", path); err != nil {
		return 0, fmt.Errorf("failed to write field path field in multipart form: %w", err)
	}
	if err := writer.WriteField("overwrite", overwriteField); err != nil {
		return 0, fmt.Errorf("failed to write field overwrite field in multipart form: %w", err)
	}
	// Only the part header is needed here; the file bytes themselves are
	// accounted for through the stat size below.
	if _, err := writer.CreateFormFile("contents", ""); err != nil {
		return 0, fmt.Errorf("failed to write contents field in multipart form: %w", err)
	}
	// Close appends the trailing boundary, which is part of the body length.
	if err := writer.Close(); err != nil {
		return 0, fmt.Errorf("failed to close multipart form writer: %w", err)
	}

	stat, err := file.Stat()
	if err != nil {
		// Report the local file that was stat'ed, not the DBFS destination.
		return 0, fmt.Errorf("failed to stat file %s: %w", file.Name(), err)
	}

	return int64(buf.Len()) + stat.Size(), nil
}
field in multipath form: %w", err)) + pw.CloseWithError(fmt.Errorf("failed to write field overwrite field in multipart form: %w", err)) return } contents, err := writer.CreateFormFile("contents", "") if err != nil { - pw.CloseWithError(fmt.Errorf("failed to write contents field in multipath form: %w", err)) + pw.CloseWithError(fmt.Errorf("failed to write contents field in multipart form: %w", err)) return } for { @@ -133,21 +164,27 @@ func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, f break } if err != nil { - pw.CloseWithError(fmt.Errorf("failed to copy file in multipath form: %w", err)) + pw.CloseWithError(fmt.Errorf("failed to copy file in multipart form: %w", err)) return } } err = writer.Close() if err != nil { - pw.CloseWithError(fmt.Errorf("failed to close multipath form writer: %w", err)) + pw.CloseWithError(fmt.Errorf("failed to close multipart form writer: %w", err)) return } }() + cl, err := putContentLength(path, overwriteField, file) + if err != nil { + return fmt.Errorf("failed to calculate content length: %w", err) + } + // Request bodies of Content-Type multipart/form-data are not supported by // the Go SDK directly for DBFS. So we use the Do method directly. - err := w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{ - "Content-Type": writer.FormDataContentType(), + err = w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{ + "Content-Type": writer.FormDataContentType(), + "Content-Length": fmt.Sprintf("%d", cl), }, pr, nil) var aerr *apierr.APIError if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {