mirror of https://github.com/databricks/cli.git
Detect Jupyter notebook files (#219)
Files with extension `.ipynb` are imported are Jupyter notebooks. This code detects 1) if the file is a valid Jupyter notebook and 2) the Databricks specific language it contains.
This commit is contained in:
parent
f93b541b63
commit
9d3a0da073
|
@ -54,6 +54,8 @@ func Detect(path string) (notebook bool, language workspace.Language, err error)
|
||||||
case ".sql":
|
case ".sql":
|
||||||
header = "-- Databricks notebook source"
|
header = "-- Databricks notebook source"
|
||||||
language = workspace.LanguageSql
|
language = workspace.LanguageSql
|
||||||
|
case ".ipynb":
|
||||||
|
return DetectJupyter(path)
|
||||||
default:
|
default:
|
||||||
return false, "", nil
|
return false, "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
package notebook
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||||
|
)
|
||||||
|
|
||||||
|
type jupyterDatabricksMetadata struct {
|
||||||
|
Language string `json:"language"`
|
||||||
|
NotebookName string `json:"notebookName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://nbformat.readthedocs.io/en/latest/format_description.html#top-level-structure.
|
||||||
|
type jupyter struct {
|
||||||
|
Cells []json.RawMessage `json:"cells,omitempty"`
|
||||||
|
Metadata map[string]json.RawMessage `json:"metadata,omitempty"`
|
||||||
|
NbFormatMajor int `json:"nbformat"`
|
||||||
|
NbFormatMinor int `json:"nbformat_minor"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveLanguage looks at Databricks specific metadata to figure out the language of the notebook.
|
||||||
|
func resolveLanguage(nb *jupyter) workspace.Language {
|
||||||
|
if nb.Metadata == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, ok := nb.Metadata["application/vnd.databricks.v1+notebook"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var metadata jupyterDatabricksMetadata
|
||||||
|
err := json.Unmarshal(raw, &metadata)
|
||||||
|
if err != nil {
|
||||||
|
// Fine to swallow error. The file must be malformed.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch metadata.Language {
|
||||||
|
case "python":
|
||||||
|
return workspace.LanguagePython
|
||||||
|
case "r":
|
||||||
|
return workspace.LanguageR
|
||||||
|
case "scala":
|
||||||
|
return workspace.LanguageScala
|
||||||
|
case "sql":
|
||||||
|
return workspace.LanguageSql
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectJupyter returns whether the file at path is a valid Jupyter notebook.
|
||||||
|
// We assume it is valid if we can read it as JSON and see a couple expected fields.
|
||||||
|
// If we cannot, importing into the workspace will always fail, so we also return an error.
|
||||||
|
func DetectJupyter(path string) (notebook bool, language workspace.Language, err error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
var nb jupyter
|
||||||
|
dec := json.NewDecoder(f)
|
||||||
|
err = dec.Decode(&nb)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", fmt.Errorf("%s: error loading Jupyter notebook file: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a Jupyter notebook if the cells or metadata fields aren't defined.
|
||||||
|
if nb.Cells == nil || nb.Metadata == nil {
|
||||||
|
return false, "", fmt.Errorf("%s: invalid Jupyter notebook file", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Major version must be at least 4.
|
||||||
|
if nb.NbFormatMajor < 4 {
|
||||||
|
return false, "", fmt.Errorf("%s: unsupported Jupyter notebook version: %d", path, nb.NbFormatMajor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, resolveLanguage(&nb), nil
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
package notebook
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDetectJupyter(t *testing.T) {
|
||||||
|
var nb bool
|
||||||
|
var lang workspace.Language
|
||||||
|
var err error
|
||||||
|
|
||||||
|
nb, lang, err = DetectJupyter("./testdata/py_ipynb.ipynb")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb)
|
||||||
|
assert.Equal(t, workspace.LanguagePython, lang)
|
||||||
|
|
||||||
|
nb, lang, err = DetectJupyter("./testdata/r_ipynb.ipynb")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb)
|
||||||
|
assert.Equal(t, workspace.LanguageR, lang)
|
||||||
|
|
||||||
|
nb, lang, err = DetectJupyter("./testdata/scala_ipynb.ipynb")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb)
|
||||||
|
assert.Equal(t, workspace.LanguageScala, lang)
|
||||||
|
|
||||||
|
nb, lang, err = DetectJupyter("./testdata/sql_ipynb.ipynb")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb)
|
||||||
|
assert.Equal(t, workspace.LanguageSql, lang)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectJupyterInvalidJSON(t *testing.T) {
|
||||||
|
// Create garbage file.
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "file.ipynb")
|
||||||
|
buf := make([]byte, 128)
|
||||||
|
err := os.WriteFile(path, buf, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Garbage contents means not a notebook.
|
||||||
|
nb, _, err := DetectJupyter(path)
|
||||||
|
require.ErrorContains(t, err, "error loading Jupyter notebook file")
|
||||||
|
assert.False(t, nb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectJupyterNoCells(t *testing.T) {
|
||||||
|
// Create empty JSON file.
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "file.ipynb")
|
||||||
|
buf := []byte("{}")
|
||||||
|
err := os.WriteFile(path, buf, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Garbage contents means not a notebook.
|
||||||
|
nb, _, err := DetectJupyter(path)
|
||||||
|
require.ErrorContains(t, err, "invalid Jupyter notebook file")
|
||||||
|
assert.False(t, nb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectJupyterOldVersion(t *testing.T) {
|
||||||
|
// Create empty JSON file.
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "file.ipynb")
|
||||||
|
buf := []byte(`{ "cells": [], "metadata": {}, "nbformat": 3 }`)
|
||||||
|
err := os.WriteFile(path, buf, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Garbage contents means not a notebook.
|
||||||
|
nb, _, err := DetectJupyter(path)
|
||||||
|
require.ErrorContains(t, err, "unsupported Jupyter notebook version")
|
||||||
|
assert.False(t, nb)
|
||||||
|
}
|
|
@ -10,27 +10,27 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDetect(t *testing.T) {
|
func TestDetectSource(t *testing.T) {
|
||||||
var nb bool
|
var nb bool
|
||||||
var lang workspace.Language
|
var lang workspace.Language
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
nb, lang, err = Detect("./testdata/py.py")
|
nb, lang, err = Detect("./testdata/py_source.py")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, nb)
|
assert.True(t, nb)
|
||||||
assert.Equal(t, workspace.LanguagePython, lang)
|
assert.Equal(t, workspace.LanguagePython, lang)
|
||||||
|
|
||||||
nb, lang, err = Detect("./testdata/r.r")
|
nb, lang, err = Detect("./testdata/r_source.r")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, nb)
|
assert.True(t, nb)
|
||||||
assert.Equal(t, workspace.LanguageR, lang)
|
assert.Equal(t, workspace.LanguageR, lang)
|
||||||
|
|
||||||
nb, lang, err = Detect("./testdata/scala.scala")
|
nb, lang, err = Detect("./testdata/scala_source.scala")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, nb)
|
assert.True(t, nb)
|
||||||
assert.Equal(t, workspace.LanguageScala, lang)
|
assert.Equal(t, workspace.LanguageScala, lang)
|
||||||
|
|
||||||
nb, lang, err = Detect("./testdata/sql.sql")
|
nb, lang, err = Detect("./testdata/sql_source.sql")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, nb)
|
assert.True(t, nb)
|
||||||
assert.Equal(t, workspace.LanguageSql, lang)
|
assert.Equal(t, workspace.LanguageSql, lang)
|
||||||
|
@ -41,6 +41,13 @@ func TestDetect(t *testing.T) {
|
||||||
assert.Equal(t, workspace.Language(""), lang)
|
assert.Equal(t, workspace.Language(""), lang)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDetectCallsDetectJupyter(t *testing.T) {
|
||||||
|
nb, lang, err := Detect("./testdata/py_ipynb.ipynb")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, nb)
|
||||||
|
assert.Equal(t, workspace.LanguagePython, lang)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDetectUnknownExtension(t *testing.T) {
|
func TestDetectUnknownExtension(t *testing.T) {
|
||||||
nb, _, err := Detect("./testdata/doesntexist.foobar")
|
nb, _, err := Detect("./testdata/doesntexist.foobar")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
Loading…
Reference in New Issue