mirror of https://github.com/databricks/cli.git
Move notebook detection logic to package (#206)
This commit is contained in:
parent
8c1b620b17
commit
58950ce507
|
@ -0,0 +1,73 @@
|
|||
package notebook
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
)
|
||||
|
||||
// headerLength is the maximum number of bytes read from the start of a file
// when looking for the notebook header. The longest header Detect compares
// against is 29 bytes, so this leaves a few bytes of padding for the line
// terminator.
const headerLength = 32

// readHeader returns up to headerLength bytes from the start of the file at
// path.
//
// The returned slice is shorter than headerLength (possibly empty) when the
// file itself is shorter; that is not treated as an error. Failures to open
// or read the file are returned unchanged.
func readHeader(path string) ([]byte, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	// A single Read call is allowed to return fewer bytes than requested even
	// when more are available, so use io.ReadFull to fill the buffer. It
	// reports io.EOF for an empty file and io.ErrUnexpectedEOF for a file
	// shorter than the buffer; both just mean "short file" here.
	buf := make([]byte, headerLength)
	n, err := io.ReadFull(f, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, err
	}

	// Trim buffer to actual read bytes.
	return buf[:n], nil
}
|
||||
|
||||
// Detect returns whether the file at path is a Databricks notebook.
|
||||
// If it is, it returns the notebook language.
|
||||
func Detect(path string) (notebook bool, language workspace.Language, err error) {
|
||||
header := ""
|
||||
|
||||
// Determine which header to expect based on filename extension.
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
switch ext {
|
||||
case ".py":
|
||||
header = `# Databricks notebook source`
|
||||
language = workspace.LanguagePython
|
||||
case ".r":
|
||||
header = `# Databricks notebook source`
|
||||
language = workspace.LanguageR
|
||||
case ".scala":
|
||||
header = "// Databricks notebook source"
|
||||
language = workspace.LanguageScala
|
||||
case ".sql":
|
||||
header = "-- Databricks notebook source"
|
||||
language = workspace.LanguageSql
|
||||
default:
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
buf, err := readHeader(path)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(buf))
|
||||
scanner.Scan()
|
||||
if scanner.Text() != header {
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
return true, language, nil
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
package notebook
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDetect(t *testing.T) {
|
||||
var nb bool
|
||||
var lang workspace.Language
|
||||
var err error
|
||||
|
||||
nb, lang, err = Detect("./testdata/py.py")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguagePython, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/r.r")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageR, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/scala.scala")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageScala, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/sql.sql")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageSql, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/txt.txt")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nb)
|
||||
assert.Equal(t, workspace.Language(""), lang)
|
||||
}
|
||||
|
||||
func TestDetectUnknownExtension(t *testing.T) {
|
||||
nb, _, err := Detect("./testdata/doesntexist.foobar")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nb)
|
||||
}
|
||||
|
||||
func TestDetectNoExtension(t *testing.T) {
|
||||
nb, _, err := Detect("./testdata/doesntexist")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nb)
|
||||
}
|
||||
|
||||
func TestDetectFileDoesNotExists(t *testing.T) {
|
||||
_, _, err := Detect("./testdata/doesntexist.py")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDetectEmptyFile(t *testing.T) {
|
||||
// Create empty file.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "file.py")
|
||||
err := os.WriteFile(path, nil, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// No contents means not a notebook.
|
||||
nb, _, err := Detect(path)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nb)
|
||||
}
|
||||
|
||||
func TestDetectFileWithLongHeader(t *testing.T) {
|
||||
// Create 128kb garbage file.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "file.py")
|
||||
buf := make([]byte, 128*1024)
|
||||
err := os.WriteFile(path, buf, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Garbage contents means not a notebook.
|
||||
nb, _, err := Detect(path)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nb)
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
# Databricks notebook source
|
||||
# hello world
|
|
@ -0,0 +1,2 @@
|
|||
# Databricks notebook source
|
||||
# hello world
|
|
@ -0,0 +1,2 @@
|
|||
// Databricks notebook source
|
||||
// hello world
|
|
@ -0,0 +1,2 @@
|
|||
-- Databricks notebook source
|
||||
-- hello world
|
|
@ -0,0 +1 @@
|
|||
hello world
|
|
@ -1,14 +1,12 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +14,7 @@ import (
|
|||
"encoding/hex"
|
||||
|
||||
"github.com/databricks/bricks/libs/fileset"
|
||||
"github.com/databricks/bricks/libs/notebook"
|
||||
)
|
||||
|
||||
// Bump it up every time a potentially breaking change is made to the snapshot schema
|
||||
|
@ -165,32 +164,6 @@ func loadOrNewSnapshot(opts *SyncOptions) (*Snapshot, error) {
|
|||
return snapshot, nil
|
||||
}
|
||||
|
||||
// pythonFileSuffix matches paths that end in ".py". It is compiled once at
// package initialization rather than on every getNotebookDetails call.
var pythonFileSuffix = regexp.MustCompile(`\.py$`)

// getNotebookDetails reports whether the file at path is a Databricks Python
// notebook.
//
// Only paths ending in ".py" are opened; for any other path the result is
// (false, "", nil). A Python file counts as a notebook when its first line
// contains the magic string "# Databricks notebook source". typeOfNotebook is
// "PYTHON" whenever a first line could be read, whether or not the file is a
// notebook, and "" otherwise.
//
// An error is returned when the file cannot be opened or read. A first line
// that is too long for bufio.Scanner is not an error; such a file is reported
// as not a notebook.
func getNotebookDetails(path string) (isNotebook bool, typeOfNotebook string, err error) {
	if !pythonFileSuffix.MatchString(path) {
		return false, "", nil
	}

	f, err := os.Open(path)
	if err != nil {
		return false, "", err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	if !scanner.Scan() {
		// Scan fails for an empty file (nil error), for a read error, or when
		// the first line exceeds the scanner's token limit. Only a real read
		// error is reported; an over-long first line must not abort the
		// caller, it simply is not a notebook header.
		if scanErr := scanner.Err(); scanErr != nil && scanErr != bufio.ErrTooLong {
			return false, "", scanErr
		}
		return false, "", nil
	}

	// A python file is a notebook if its first line contains the following magic string
	isNotebook = strings.Contains(scanner.Text(), "# Databricks notebook source")
	return isNotebook, "PYTHON", nil
}
|
||||
|
||||
func (s *Snapshot) diff(all []fileset.File) (change diff, err error) {
|
||||
currentFilenames := map[string]bool{}
|
||||
lastModifiedTimes := s.LastUpdatedTimes
|
||||
|
@ -213,15 +186,16 @@ func (s *Snapshot) diff(all []fileset.File) (change diff, err error) {
|
|||
change.put = append(change.put, unixFileName)
|
||||
|
||||
// get file metadata about whether it's a notebook
|
||||
isNotebook, typeOfNotebook, err := getNotebookDetails(f.Absolute)
|
||||
isNotebook, _, err := notebook.Detect(f.Absolute)
|
||||
if err != nil {
|
||||
return change, err
|
||||
}
|
||||
|
||||
// strip `.py` for python notebooks
|
||||
// Strip extension for notebooks.
|
||||
remoteName := unixFileName
|
||||
if isNotebook && typeOfNotebook == "PYTHON" {
|
||||
remoteName = strings.TrimSuffix(remoteName, `.py`)
|
||||
if isNotebook {
|
||||
ext := filepath.Ext(remoteName)
|
||||
remoteName = strings.TrimSuffix(remoteName, ext)
|
||||
}
|
||||
|
||||
// If the remote handle of a file changes, we want to delete the old
|
||||
|
|
Loading…
Reference in New Issue