calculate content length before upload

This commit is contained in:
Shreyas Goenka 2025-01-02 17:20:22 +05:30
parent 890b48f70d
commit f70c47253e
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
1 changed files with 44 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package filer package filer
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -99,6 +100,36 @@ func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
}, nil }, nil
} }
// putContentLength returns the total size, in bytes, of the multipart/form-data
// body for a DBFS put request, for use as the Content-Length header value. The
// PUT API for DBFS requires that header to be set before the body is sent.
//
// The size is computed without reading the file: the multipart envelope (the
// "path" and "overwrite" text fields plus an empty "contents" file part) is
// rendered into a scratch buffer, and the file size reported by Stat is added
// to the length of that envelope.
//
// NOTE(review): this relies on the writer used for the real upload producing
// framing of the same length as the throwaway writer created here, and on the
// whole file (from offset 0) being uploaded — confirm against the caller.
//
// path is the value of the "path" form field (it is also used in the stat
// error message), overwriteField is the value of the "overwrite" form field,
// and file is the file whose size is counted as the "contents" payload.
func putContentLength(path string, overwriteField string, file *os.File) (int64, error) {
	var envelope bytes.Buffer
	form := multipart.NewWriter(&envelope)

	// Plain text fields, written in the order they appear in the request body.
	textFields := []struct{ name, value string }{
		{"path", path},
		{"overwrite", overwriteField},
	}
	for _, f := range textFields {
		if err := form.WriteField(f.name, f.value); err != nil {
			return 0, fmt.Errorf("failed to write field %s field in multipart form: %w", f.name, err)
		}
	}

	// Only the part header is emitted for "contents"; the payload bytes are
	// accounted for via the file size below.
	if _, err := form.CreateFormFile("contents", ""); err != nil {
		return 0, fmt.Errorf("failed to write contents field in multipart form: %w", err)
	}

	// Closing the writer appends the trailing boundary, which is part of the body.
	if err := form.Close(); err != nil {
		return 0, fmt.Errorf("failed to close multipart form writer: %w", err)
	}

	info, err := file.Stat()
	if err != nil {
		return 0, fmt.Errorf("failed to stat file %s: %w", path, err)
	}
	return info.Size() + int64(envelope.Len()), nil
}
func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, file *os.File) error { func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, file *os.File) error {
overwriteField := "False" overwriteField := "False"
if overwrite { if overwrite {
@ -112,17 +143,17 @@ func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, f
err := writer.WriteField("path", path) err := writer.WriteField("path", path)
if err != nil { if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field path field in multipath form: %w", err)) pw.CloseWithError(fmt.Errorf("failed to write field path field in multipart form: %w", err))
return return
} }
err = writer.WriteField("overwrite", overwriteField) err = writer.WriteField("overwrite", overwriteField)
if err != nil { if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field overwrite field in multipath form: %w", err)) pw.CloseWithError(fmt.Errorf("failed to write field overwrite field in multipart form: %w", err))
return return
} }
contents, err := writer.CreateFormFile("contents", "") contents, err := writer.CreateFormFile("contents", "")
if err != nil { if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write contents field in multipath form: %w", err)) pw.CloseWithError(fmt.Errorf("failed to write contents field in multipart form: %w", err))
return return
} }
for { for {
@ -133,21 +164,27 @@ func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, f
break break
} }
if err != nil { if err != nil {
pw.CloseWithError(fmt.Errorf("failed to copy file in multipath form: %w", err)) pw.CloseWithError(fmt.Errorf("failed to copy file in multipart form: %w", err))
return return
} }
} }
err = writer.Close() err = writer.Close()
if err != nil { if err != nil {
pw.CloseWithError(fmt.Errorf("failed to close multipath form writer: %w", err)) pw.CloseWithError(fmt.Errorf("failed to close multipart form writer: %w", err))
return return
} }
}() }()
cl, err := putContentLength(path, overwriteField, file)
if err != nil {
return fmt.Errorf("failed to calculate content length: %w", err)
}
// Request bodies of Content-Type multipart/form-data are not supported by // Request bodies of Content-Type multipart/form-data are not supported by
// the Go SDK directly for DBFS. So we use the Do method directly. // the Go SDK directly for DBFS. So we use the Do method directly.
err := w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{ err = w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{
"Content-Type": writer.FormDataContentType(), "Content-Type": writer.FormDataContentType(),
"Content-Length": fmt.Sprintf("%d", cl),
}, pr, nil) }, pr, nil)
var aerr *apierr.APIError var aerr *apierr.APIError
if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" { if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {