mirror of https://github.com/databricks/cli.git
refactor to make tests easier
This commit is contained in:
parent
91a2dfa0ed
commit
06af01c8f6
|
@ -68,28 +68,37 @@ func (info dbfsFileInfo) Sys() any {
|
||||||
return info.fi
|
return info.fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Interface to allow mocking of the Databricks API client.
|
||||||
|
type databricksClient interface {
|
||||||
|
Do(ctx context.Context, method, path string, headers map[string]string,
|
||||||
|
requestBody any, responseBody any, visitors ...func(*http.Request) error) error
|
||||||
|
}
|
||||||
|
|
||||||
// DbfsClient implements the [Filer] interface for the DBFS backend.
|
// DbfsClient implements the [Filer] interface for the DBFS backend.
|
||||||
type DbfsClient struct {
|
type DbfsClient struct {
|
||||||
workspaceClient *databricks.WorkspaceClient
|
workspaceClient *databricks.WorkspaceClient
|
||||||
|
|
||||||
|
apiClient databricksClient
|
||||||
|
|
||||||
// File operations will be relative to this path.
|
// File operations will be relative to this path.
|
||||||
root WorkspaceRootPath
|
root WorkspaceRootPath
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
|
func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
|
||||||
|
apiClient, err := client.New(w.Config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create API client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &DbfsClient{
|
return &DbfsClient{
|
||||||
workspaceClient: w,
|
workspaceClient: w,
|
||||||
|
apiClient: apiClient,
|
||||||
|
|
||||||
root: NewWorkspaceRootPath(root),
|
root: NewWorkspaceRootPath(root),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *DbfsClient) uploadUsingDbfsPutApi(ctx context.Context, path string, overwrite bool, file *os.File) error {
|
func (w *DbfsClient) uploadUsingDbfsPutApi(ctx context.Context, path string, overwrite bool, file *os.File) error {
|
||||||
apiClient, err := client.New(w.workspaceClient.Config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create API client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
overwriteField := "False"
|
overwriteField := "False"
|
||||||
if overwrite {
|
if overwrite {
|
||||||
overwriteField = "True"
|
overwriteField = "True"
|
||||||
|
@ -97,7 +106,7 @@ func (w *DbfsClient) uploadUsingDbfsPutApi(ctx context.Context, path string, ove
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
writer := multipart.NewWriter(buf)
|
writer := multipart.NewWriter(buf)
|
||||||
err = writer.WriteField("path", path)
|
err := writer.WriteField("path", path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -122,7 +131,7 @@ func (w *DbfsClient) uploadUsingDbfsPutApi(ctx context.Context, path string, ove
|
||||||
|
|
||||||
// Request bodies of Content-Type multipart/form-data must are not supported by
|
// Request bodies of Content-Type multipart/form-data must are not supported by
|
||||||
// the Go SDK directly for DBFS. So we use the Do method directly.
|
// the Go SDK directly for DBFS. So we use the Do method directly.
|
||||||
return apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{
|
return w.apiClient.Do(ctx, http.MethodPost, "/api/2.0/dbfs/put", map[string]string{
|
||||||
"Content-Type": writer.FormDataContentType(),
|
"Content-Type": writer.FormDataContentType(),
|
||||||
}, buf.Bytes(), nil)
|
}, buf.Bytes(), nil)
|
||||||
}
|
}
|
||||||
|
@ -161,7 +170,7 @@ func (w *DbfsClient) uploadUsingDbfsStreamingApi(ctx context.Context, path strin
|
||||||
// MaxUploadLimitForPutApi is the maximum size in bytes of a file that can be uploaded
|
// MaxUploadLimitForPutApi is the maximum size in bytes of a file that can be uploaded
|
||||||
// using the /dbfs/put API. If the file is larger than this limit, the streaming
|
// using the /dbfs/put API. If the file is larger than this limit, the streaming
|
||||||
// API (/dbfs/create and /dbfs/add-block) will be used instead.
|
// API (/dbfs/create and /dbfs/add-block) will be used instead.
|
||||||
var MaxUploadLimitForPutApi int64 = 2 * 1024 * 1024
|
var MaxDbfsUploadLimitForPutApi int64 = 2 * 1024 * 1024
|
||||||
|
|
||||||
func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
|
func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
|
||||||
absPath, err := w.root.Join(name)
|
absPath, err := w.root.Join(name)
|
||||||
|
@ -211,7 +220,7 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the source is a local file, but is too large then we'll use the streaming API endpoint.
|
// If the source is a local file, but is too large then we'll use the streaming API endpoint.
|
||||||
if stat.Size() > MaxUploadLimitForPutApi {
|
if stat.Size() > MaxDbfsUploadLimitForPutApi {
|
||||||
return w.uploadUsingDbfsStreamingApi(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
|
return w.uploadUsingDbfsStreamingApi(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue