diff --git a/libs/notebook/detect.go b/libs/notebook/detect.go index 19b580cc..ad751fa0 100644 --- a/libs/notebook/detect.go +++ b/libs/notebook/detect.go @@ -54,6 +54,8 @@ func Detect(path string) (notebook bool, language workspace.Language, err error) case ".sql": header = "-- Databricks notebook source" language = workspace.LanguageSql + case ".ipynb": + return DetectJupyter(path) default: return false, "", nil } diff --git a/libs/notebook/detect_jupyter.go b/libs/notebook/detect_jupyter.go new file mode 100644 index 00000000..7d96763c --- /dev/null +++ b/libs/notebook/detect_jupyter.go @@ -0,0 +1,85 @@ +package notebook + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/databricks/databricks-sdk-go/service/workspace" +) + +type jupyterDatabricksMetadata struct { + Language string `json:"language"` + NotebookName string `json:"notebookName"` +} + +// See https://nbformat.readthedocs.io/en/latest/format_description.html#top-level-structure. +type jupyter struct { + Cells []json.RawMessage `json:"cells,omitempty"` + Metadata map[string]json.RawMessage `json:"metadata,omitempty"` + NbFormatMajor int `json:"nbformat"` + NbFormatMinor int `json:"nbformat_minor"` +} + +// resolveLanguage looks at Databricks specific metadata to figure out the language of the notebook. +func resolveLanguage(nb *jupyter) workspace.Language { + if nb.Metadata == nil { + return "" + } + + raw, ok := nb.Metadata["application/vnd.databricks.v1+notebook"] + if !ok { + return "" + } + + var metadata jupyterDatabricksMetadata + err := json.Unmarshal(raw, &metadata) + if err != nil { + // Fine to swallow error. The file must be malformed. + return "" + } + + switch metadata.Language { + case "python": + return workspace.LanguagePython + case "r": + return workspace.LanguageR + case "scala": + return workspace.LanguageScala + case "sql": + return workspace.LanguageSql + default: + return "" + } +} + +// DetectJupyter returns whether the file at path is a valid Jupyter notebook. +// We assume it is valid if we can read it as JSON and see a couple expected fields. +// If we cannot, importing into the workspace will always fail, so we also return an error. +func DetectJupyter(path string) (notebook bool, language workspace.Language, err error) { + f, err := os.Open(path) + if err != nil { + return false, "", err + } + + defer f.Close() + + var nb jupyter + dec := json.NewDecoder(f) + err = dec.Decode(&nb) + if err != nil { + return false, "", fmt.Errorf("%s: error loading Jupyter notebook file: %w", path, err) + } + + // Not a Jupyter notebook if the cells or metadata fields aren't defined. + if nb.Cells == nil || nb.Metadata == nil { + return false, "", fmt.Errorf("%s: invalid Jupyter notebook file", path) + } + + // Major version must be at least 4. + if nb.NbFormatMajor < 4 { + return false, "", fmt.Errorf("%s: unsupported Jupyter notebook version: %d", path, nb.NbFormatMajor) + } + + return true, resolveLanguage(&nb), nil +} diff --git a/libs/notebook/detect_jupyter_test.go b/libs/notebook/detect_jupyter_test.go new file mode 100644 index 00000000..4ff2aeff --- /dev/null +++ b/libs/notebook/detect_jupyter_test.go @@ -0,0 +1,79 @@ +package notebook + +import ( + "os" + "path/filepath" + "testing" + + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetectJupyter(t *testing.T) { + var nb bool + var lang workspace.Language + var err error + + nb, lang, err = DetectJupyter("./testdata/py_ipynb.ipynb") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguagePython, lang) + + nb, lang, err = DetectJupyter("./testdata/r_ipynb.ipynb") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageR, lang) + + nb, lang, err = DetectJupyter("./testdata/scala_ipynb.ipynb") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageScala, lang) + + nb, lang, err = DetectJupyter("./testdata/sql_ipynb.ipynb") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguageSql, lang) +} + +func TestDetectJupyterInvalidJSON(t *testing.T) { + // Create garbage file. + dir := t.TempDir() + path := filepath.Join(dir, "file.ipynb") + buf := make([]byte, 128) + err := os.WriteFile(path, buf, 0644) + require.NoError(t, err) + + // Garbage contents means not a notebook. + nb, _, err := DetectJupyter(path) + require.ErrorContains(t, err, "error loading Jupyter notebook file") + assert.False(t, nb) +} + +func TestDetectJupyterNoCells(t *testing.T) { + // Create empty JSON file. + dir := t.TempDir() + path := filepath.Join(dir, "file.ipynb") + buf := []byte("{}") + err := os.WriteFile(path, buf, 0644) + require.NoError(t, err) + + // Garbage contents means not a notebook. + nb, _, err := DetectJupyter(path) + require.ErrorContains(t, err, "invalid Jupyter notebook file") + assert.False(t, nb) +} + +func TestDetectJupyterOldVersion(t *testing.T) { + // Create empty JSON file. + dir := t.TempDir() + path := filepath.Join(dir, "file.ipynb") + buf := []byte(`{ "cells": [], "metadata": {}, "nbformat": 3 }`) + err := os.WriteFile(path, buf, 0644) + require.NoError(t, err) + + // Garbage contents means not a notebook. + nb, _, err := DetectJupyter(path) + require.ErrorContains(t, err, "unsupported Jupyter notebook version") + assert.False(t, nb) +} diff --git a/libs/notebook/detect_test.go b/libs/notebook/detect_test.go index ba3cdaba..74c13e70 100644 --- a/libs/notebook/detect_test.go +++ b/libs/notebook/detect_test.go @@ -10,27 +10,27 @@ import ( "github.com/stretchr/testify/require" ) -func TestDetect(t *testing.T) { +func TestDetectSource(t *testing.T) { var nb bool var lang workspace.Language var err error - nb, lang, err = Detect("./testdata/py.py") + nb, lang, err = Detect("./testdata/py_source.py") require.NoError(t, err) assert.True(t, nb) assert.Equal(t, workspace.LanguagePython, lang) - nb, lang, err = Detect("./testdata/r.r") + nb, lang, err = Detect("./testdata/r_source.r") require.NoError(t, err) assert.True(t, nb) assert.Equal(t, workspace.LanguageR, lang) - nb, lang, err = Detect("./testdata/scala.scala") + nb, lang, err = Detect("./testdata/scala_source.scala") require.NoError(t, err) assert.True(t, nb) assert.Equal(t, workspace.LanguageScala, lang) - nb, lang, err = Detect("./testdata/sql.sql") + nb, lang, err = Detect("./testdata/sql_source.sql") require.NoError(t, err) assert.True(t, nb) assert.Equal(t, workspace.LanguageSql, lang) @@ -41,6 +41,13 @@ func TestDetect(t *testing.T) { assert.Equal(t, workspace.Language(""), lang) } +func TestDetectCallsDetectJupyter(t *testing.T) { + nb, lang, err := Detect("./testdata/py_ipynb.ipynb") + require.NoError(t, err) + assert.True(t, nb) + assert.Equal(t, workspace.LanguagePython, lang) +} + func TestDetectUnknownExtension(t *testing.T) { nb, _, err := Detect("./testdata/doesntexist.foobar") require.NoError(t, err) diff --git a/libs/notebook/testdata/py.py b/libs/notebook/testdata/py_source.py similarity index 100% rename from libs/notebook/testdata/py.py rename to libs/notebook/testdata/py_source.py diff --git a/libs/notebook/testdata/r.r b/libs/notebook/testdata/r_source.r similarity index 100% rename from libs/notebook/testdata/r.r rename to libs/notebook/testdata/r_source.r diff --git a/libs/notebook/testdata/scala.scala b/libs/notebook/testdata/scala_source.scala similarity index 100% rename from libs/notebook/testdata/scala.scala rename to libs/notebook/testdata/scala_source.scala diff --git a/libs/notebook/testdata/sql.sql b/libs/notebook/testdata/sql_source.sql similarity index 100% rename from libs/notebook/testdata/sql.sql rename to libs/notebook/testdata/sql_source.sql