use env.lookup

This commit is contained in:
Fabian Jakobs 2024-09-03 16:08:28 +02:00
parent 662234fb97
commit 8d78809469
No known key found for this signature in database
1 changed files with 13 additions and 9 deletions

View File

@ -9,6 +9,7 @@ import (
"path/filepath"
"strings"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/workspace"
@ -77,12 +78,14 @@ func (f *copyFile) PersistToDisk() error {
}
defer srcFile.Close()
if runsOnDatabricks() && f.isNotebook() {
ctx := context.Background()
if runsOnDatabricks(ctx) && f.isNotebook() {
content, err := io.ReadAll(srcFile)
if err != nil {
return err
}
return writeNotebook(path, content)
return writeNotebook(ctx, path, content)
} else {
dstFile, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, f.perm)
if err != nil {
@ -119,18 +122,21 @@ func (f *inMemoryFile) PersistToDisk() error {
return err
}
if runsOnDatabricks() && f.isNotebook() {
return writeNotebook(path, f.content)
ctx := context.Background()
if runsOnDatabricks(ctx) && f.isNotebook() {
return writeNotebook(ctx, path, f.content)
} else {
return os.WriteFile(path, f.content, f.perm)
}
}
func runsOnDatabricks() bool {
return os.Getenv("DATABRICKS_RUNTIME_VERSION") != ""
func runsOnDatabricks(ctx context.Context) bool {
_, ok := env.Lookup(ctx, "DATABRICKS_RUNTIME_VERSION")
return ok
}
func writeNotebook(path string, content []byte) error {
func writeNotebook(ctx context.Context, path string, content []byte) error {
if !strings.HasPrefix(path, "/Workspace/") {
return os.WriteFile(path, content, 0644)
} else {
@ -139,8 +145,6 @@ func writeNotebook(path string, content []byte) error {
return err
}
ctx := context.Background()
err = w.Workspace.Import(ctx, workspace.Import{
Format: "AUTO",
Overwrite: false,