mirror of https://github.com/databricks/cli.git
Detect Jupyter notebook files (#219)
Files with extension `.ipynb` are imported are Jupyter notebooks. This code detects 1) if the file is a valid Jupyter notebook and 2) the Databricks specific language it contains.
This commit is contained in:
parent
f93b541b63
commit
9d3a0da073
|
@ -54,6 +54,8 @@ func Detect(path string) (notebook bool, language workspace.Language, err error)
|
|||
case ".sql":
|
||||
header = "-- Databricks notebook source"
|
||||
language = workspace.LanguageSql
|
||||
case ".ipynb":
|
||||
return DetectJupyter(path)
|
||||
default:
|
||||
return false, "", nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
package notebook
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
)
|
||||
|
||||
type jupyterDatabricksMetadata struct {
|
||||
Language string `json:"language"`
|
||||
NotebookName string `json:"notebookName"`
|
||||
}
|
||||
|
||||
// See https://nbformat.readthedocs.io/en/latest/format_description.html#top-level-structure.
|
||||
type jupyter struct {
|
||||
Cells []json.RawMessage `json:"cells,omitempty"`
|
||||
Metadata map[string]json.RawMessage `json:"metadata,omitempty"`
|
||||
NbFormatMajor int `json:"nbformat"`
|
||||
NbFormatMinor int `json:"nbformat_minor"`
|
||||
}
|
||||
|
||||
// resolveLanguage looks at Databricks specific metadata to figure out the language of the notebook.
|
||||
func resolveLanguage(nb *jupyter) workspace.Language {
|
||||
if nb.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
raw, ok := nb.Metadata["application/vnd.databricks.v1+notebook"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
var metadata jupyterDatabricksMetadata
|
||||
err := json.Unmarshal(raw, &metadata)
|
||||
if err != nil {
|
||||
// Fine to swallow error. The file must be malformed.
|
||||
return ""
|
||||
}
|
||||
|
||||
switch metadata.Language {
|
||||
case "python":
|
||||
return workspace.LanguagePython
|
||||
case "r":
|
||||
return workspace.LanguageR
|
||||
case "scala":
|
||||
return workspace.LanguageScala
|
||||
case "sql":
|
||||
return workspace.LanguageSql
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// DetectJupyter returns whether the file at path is a valid Jupyter notebook.
|
||||
// We assume it is valid if we can read it as JSON and see a couple expected fields.
|
||||
// If we cannot, importing into the workspace will always fail, so we also return an error.
|
||||
func DetectJupyter(path string) (notebook bool, language workspace.Language, err error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
var nb jupyter
|
||||
dec := json.NewDecoder(f)
|
||||
err = dec.Decode(&nb)
|
||||
if err != nil {
|
||||
return false, "", fmt.Errorf("%s: error loading Jupyter notebook file: %w", path, err)
|
||||
}
|
||||
|
||||
// Not a Jupyter notebook if the cells or metadata fields aren't defined.
|
||||
if nb.Cells == nil || nb.Metadata == nil {
|
||||
return false, "", fmt.Errorf("%s: invalid Jupyter notebook file", path)
|
||||
}
|
||||
|
||||
// Major version must be at least 4.
|
||||
if nb.NbFormatMajor < 4 {
|
||||
return false, "", fmt.Errorf("%s: unsupported Jupyter notebook version: %d", path, nb.NbFormatMajor)
|
||||
}
|
||||
|
||||
return true, resolveLanguage(&nb), nil
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package notebook
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDetectJupyter(t *testing.T) {
|
||||
var nb bool
|
||||
var lang workspace.Language
|
||||
var err error
|
||||
|
||||
nb, lang, err = DetectJupyter("./testdata/py_ipynb.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguagePython, lang)
|
||||
|
||||
nb, lang, err = DetectJupyter("./testdata/r_ipynb.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageR, lang)
|
||||
|
||||
nb, lang, err = DetectJupyter("./testdata/scala_ipynb.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageScala, lang)
|
||||
|
||||
nb, lang, err = DetectJupyter("./testdata/sql_ipynb.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageSql, lang)
|
||||
}
|
||||
|
||||
func TestDetectJupyterInvalidJSON(t *testing.T) {
|
||||
// Create garbage file.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "file.ipynb")
|
||||
buf := make([]byte, 128)
|
||||
err := os.WriteFile(path, buf, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Garbage contents means not a notebook.
|
||||
nb, _, err := DetectJupyter(path)
|
||||
require.ErrorContains(t, err, "error loading Jupyter notebook file")
|
||||
assert.False(t, nb)
|
||||
}
|
||||
|
||||
func TestDetectJupyterNoCells(t *testing.T) {
|
||||
// Create empty JSON file.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "file.ipynb")
|
||||
buf := []byte("{}")
|
||||
err := os.WriteFile(path, buf, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Garbage contents means not a notebook.
|
||||
nb, _, err := DetectJupyter(path)
|
||||
require.ErrorContains(t, err, "invalid Jupyter notebook file")
|
||||
assert.False(t, nb)
|
||||
}
|
||||
|
||||
func TestDetectJupyterOldVersion(t *testing.T) {
|
||||
// Create empty JSON file.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "file.ipynb")
|
||||
buf := []byte(`{ "cells": [], "metadata": {}, "nbformat": 3 }`)
|
||||
err := os.WriteFile(path, buf, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Garbage contents means not a notebook.
|
||||
nb, _, err := DetectJupyter(path)
|
||||
require.ErrorContains(t, err, "unsupported Jupyter notebook version")
|
||||
assert.False(t, nb)
|
||||
}
|
|
@ -10,27 +10,27 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDetect(t *testing.T) {
|
||||
func TestDetectSource(t *testing.T) {
|
||||
var nb bool
|
||||
var lang workspace.Language
|
||||
var err error
|
||||
|
||||
nb, lang, err = Detect("./testdata/py.py")
|
||||
nb, lang, err = Detect("./testdata/py_source.py")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguagePython, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/r.r")
|
||||
nb, lang, err = Detect("./testdata/r_source.r")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageR, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/scala.scala")
|
||||
nb, lang, err = Detect("./testdata/scala_source.scala")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageScala, lang)
|
||||
|
||||
nb, lang, err = Detect("./testdata/sql.sql")
|
||||
nb, lang, err = Detect("./testdata/sql_source.sql")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguageSql, lang)
|
||||
|
@ -41,6 +41,13 @@ func TestDetect(t *testing.T) {
|
|||
assert.Equal(t, workspace.Language(""), lang)
|
||||
}
|
||||
|
||||
func TestDetectCallsDetectJupyter(t *testing.T) {
|
||||
nb, lang, err := Detect("./testdata/py_ipynb.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb)
|
||||
assert.Equal(t, workspace.LanguagePython, lang)
|
||||
}
|
||||
|
||||
func TestDetectUnknownExtension(t *testing.T) {
|
||||
nb, _, err := Detect("./testdata/doesntexist.foobar")
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue