diff --git a/libs/notebook/detect.go b/libs/notebook/detect.go new file mode 100644 index 00000000..19b580cc --- /dev/null +++ b/libs/notebook/detect.go @@ -0,0 +1,73 @@ +package notebook + +import ( + "bufio" + "bytes" + "io" + "os" + "path/filepath" + "strings" + + "github.com/databricks/databricks-sdk-go/service/workspace" +) + +// Maximum length in bytes of the notebook header. +const headerLength = 32 + +// readHeader reads the first N bytes from a file. +func readHeader(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + defer f.Close() + + // Scan header line with some padding. + var buf = make([]byte, headerLength) + n, err := f.Read([]byte(buf)) + if err != nil && err != io.EOF { + return nil, err + } + + // Trim buffer to actual read bytes. + return buf[:n], nil +} + +// Detect returns whether the file at path is a Databricks notebook. +// If it is, it returns the notebook language. +func Detect(path string) (notebook bool, language workspace.Language, err error) { + header := "" + + // Determine which header to expect based on filename extension. + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".py": + header = `# Databricks notebook source` + language = workspace.LanguagePython + case ".r": + header = `# Databricks notebook source` + language = workspace.LanguageR + case ".scala": + header = "// Databricks notebook source" + language = workspace.LanguageScala + case ".sql": + header = "-- Databricks notebook source" + language = workspace.LanguageSql + default: + return false, "", nil + } + + buf, err := readHeader(path) + if err != nil { + return false, "", err + } + + scanner := bufio.NewScanner(bytes.NewReader(buf)) + scanner.Scan() + if scanner.Text() != header { + return false, "", nil + } + + return true, language, nil +} diff --git a/libs/notebook/detect_test.go b/libs/notebook/detect_test.go new file mode 100644 index 00000000..ba3cdaba --- /dev/null +++ b/libs/notebook/detect_test.go @@ -0,0 +1,86 @@ +package notebook + +import ( + "os" + "path/filepath" + "testing" + + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetect(t *testing.T) { + var nb bool + var lang workspace.Language + var err error + + nb, lang, err = Detect("./testdata/py.py") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguagePython, lang) + + nb, lang, err = Detect("./testdata/r.r") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageR, lang) + + nb, lang, err = Detect("./testdata/scala.scala") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageScala, lang) + + nb, lang, err = Detect("./testdata/sql.sql") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageSql, lang) + + nb, lang, err = Detect("./testdata/txt.txt") + require.NoError(t, err) + assert.False(t, nb) + assert.Equal(t, workspace.Language(""), lang) +} + +func TestDetectUnknownExtension(t *testing.T) { + nb, _, err := Detect("./testdata/doesntexist.foobar") + require.NoError(t, err) + assert.False(t, nb) +} + +func TestDetectNoExtension(t *testing.T) { + nb, _, err := Detect("./testdata/doesntexist") + require.NoError(t, err) + assert.False(t, nb) +} + +func TestDetectFileDoesNotExists(t *testing.T) { + _, _, err := Detect("./testdata/doesntexist.py") + require.Error(t, err) +} + +func TestDetectEmptyFile(t *testing.T) { + // Create empty file. + dir := t.TempDir() + path := filepath.Join(dir, "file.py") + err := os.WriteFile(path, nil, 0644) + require.NoError(t, err) + + // No contents means not a notebook. + nb, _, err := Detect(path) + require.NoError(t, err) + assert.False(t, nb) +} + +func TestDetectFileWithLongHeader(t *testing.T) { + // Create 128kb garbage file. + dir := t.TempDir() + path := filepath.Join(dir, "file.py") + buf := make([]byte, 128*1024) + err := os.WriteFile(path, buf, 0644) + require.NoError(t, err) + + // Garbage contents means not a notebook. + nb, _, err := Detect(path) + require.NoError(t, err) + assert.False(t, nb) +} diff --git a/libs/notebook/testdata/py.py b/libs/notebook/testdata/py.py new file mode 100644 index 00000000..5fca3802 --- /dev/null +++ b/libs/notebook/testdata/py.py @@ -0,0 +1,2 @@ +# Databricks notebook source +# hello world diff --git a/libs/notebook/testdata/r.r b/libs/notebook/testdata/r.r new file mode 100644 index 00000000..5fca3802 --- /dev/null +++ b/libs/notebook/testdata/r.r @@ -0,0 +1,2 @@ +# Databricks notebook source +# hello world diff --git a/libs/notebook/testdata/scala.scala b/libs/notebook/testdata/scala.scala new file mode 100644 index 00000000..1d338e8a --- /dev/null +++ b/libs/notebook/testdata/scala.scala @@ -0,0 +1,2 @@ +// Databricks notebook source +// hello world diff --git a/libs/notebook/testdata/sql.sql b/libs/notebook/testdata/sql.sql new file mode 100644 index 00000000..442e46c4 --- /dev/null +++ b/libs/notebook/testdata/sql.sql @@ -0,0 +1,2 @@ +-- Databricks notebook source +-- hello world diff --git a/libs/notebook/testdata/txt.txt b/libs/notebook/testdata/txt.txt new file mode 100644 index 00000000..3b18e512 --- /dev/null +++ b/libs/notebook/testdata/txt.txt @@ -0,0 +1 @@ +hello world diff --git a/libs/sync/snapshot.go b/libs/sync/snapshot.go index d77e41d2..96c04963 100644 --- a/libs/sync/snapshot.go +++ b/libs/sync/snapshot.go @@ -1,14 +1,12 @@ package sync import ( - "bufio" "context" "encoding/json" "fmt" "log" "os" "path/filepath" - "regexp" "strings" "time" @@ -16,6 +14,7 @@ import ( "encoding/hex" "github.com/databricks/bricks/libs/fileset" + "github.com/databricks/bricks/libs/notebook" ) // Bump it up every time a potentially breaking change is made to the snapshot schema @@ -165,32 +164,6 @@ func loadOrNewSnapshot(opts *SyncOptions) (*Snapshot, error) { return snapshot, nil } -func getNotebookDetails(path string) (isNotebook bool, typeOfNotebook string, err error) { - isNotebook = false - typeOfNotebook = "" - - isPythonFile, err := regexp.Match(`\.py$`, []byte(path)) - if err != nil { - return - } - if isPythonFile { - f, err := os.Open(path) - if err != nil { - return false, "", err - } - defer f.Close() - scanner := bufio.NewScanner(f) - ok := scanner.Scan() - if !ok { - return false, "", scanner.Err() - } - // A python file is a notebook if it starts with the following magic string - isNotebook = strings.Contains(scanner.Text(), "# Databricks notebook source") - return isNotebook, "PYTHON", nil - } - return false, "", nil -} - func (s *Snapshot) diff(all []fileset.File) (change diff, err error) { currentFilenames := map[string]bool{} lastModifiedTimes := s.LastUpdatedTimes @@ -213,15 +186,16 @@ func (s *Snapshot) diff(all []fileset.File) (change diff, err error) { change.put = append(change.put, unixFileName) // get file metadata about whether it's a notebook - isNotebook, typeOfNotebook, err := getNotebookDetails(f.Absolute) + isNotebook, _, err := notebook.Detect(f.Absolute) if err != nil { return change, err } - // strip `.py` for python notebooks + // Strip extension for notebooks. remoteName := unixFileName - if isNotebook && typeOfNotebook == "PYTHON" { - remoteName = strings.TrimSuffix(remoteName, `.py`) + if isNotebook { + ext := filepath.Ext(remoteName) + remoteName = strings.TrimSuffix(remoteName, ext) } // If the remote handle of a file changes, we want to delete the old