Detect Jupyter notebook files (#219)

Files with extension `.ipynb` are imported are Jupyter notebooks.

This code detects 1) if the file is a valid Jupyter notebook and 2) the
Databricks specific language it contains.
This commit is contained in:
Pieter Noordhuis 2023-02-21 13:49:01 +01:00 committed by GitHub
parent f93b541b63
commit 9d3a0da073
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 178 additions and 5 deletions

View File

@ -54,6 +54,8 @@ func Detect(path string) (notebook bool, language workspace.Language, err error)
case ".sql": case ".sql":
header = "-- Databricks notebook source" header = "-- Databricks notebook source"
language = workspace.LanguageSql language = workspace.LanguageSql
case ".ipynb":
return DetectJupyter(path)
default: default:
return false, "", nil return false, "", nil
} }

View File

@ -0,0 +1,85 @@
package notebook
import (
"encoding/json"
"fmt"
"os"
"github.com/databricks/databricks-sdk-go/service/workspace"
)
type jupyterDatabricksMetadata struct {
Language string `json:"language"`
NotebookName string `json:"notebookName"`
}
// See https://nbformat.readthedocs.io/en/latest/format_description.html#top-level-structure.
type jupyter struct {
Cells []json.RawMessage `json:"cells,omitempty"`
Metadata map[string]json.RawMessage `json:"metadata,omitempty"`
NbFormatMajor int `json:"nbformat"`
NbFormatMinor int `json:"nbformat_minor"`
}
// resolveLanguage looks at Databricks specific metadata to figure out the language of the notebook.
func resolveLanguage(nb *jupyter) workspace.Language {
if nb.Metadata == nil {
return ""
}
raw, ok := nb.Metadata["application/vnd.databricks.v1+notebook"]
if !ok {
return ""
}
var metadata jupyterDatabricksMetadata
err := json.Unmarshal(raw, &metadata)
if err != nil {
// Fine to swallow error. The file must be malformed.
return ""
}
switch metadata.Language {
case "python":
return workspace.LanguagePython
case "r":
return workspace.LanguageR
case "scala":
return workspace.LanguageScala
case "sql":
return workspace.LanguageSql
default:
return ""
}
}
// DetectJupyter returns whether the file at path is a valid Jupyter notebook.
// We assume it is valid if we can read it as JSON and see a couple expected fields.
// If we cannot, importing into the workspace will always fail, so we also return an error.
func DetectJupyter(path string) (notebook bool, language workspace.Language, err error) {
f, err := os.Open(path)
if err != nil {
return false, "", err
}
defer f.Close()
var nb jupyter
dec := json.NewDecoder(f)
err = dec.Decode(&nb)
if err != nil {
return false, "", fmt.Errorf("%s: error loading Jupyter notebook file: %w", path, err)
}
// Not a Jupyter notebook if the cells or metadata fields aren't defined.
if nb.Cells == nil || nb.Metadata == nil {
return false, "", fmt.Errorf("%s: invalid Jupyter notebook file", path)
}
// Major version must be at least 4.
if nb.NbFormatMajor < 4 {
return false, "", fmt.Errorf("%s: unsupported Jupyter notebook version: %d", path, nb.NbFormatMajor)
}
return true, resolveLanguage(&nb), nil
}

View File

@ -0,0 +1,79 @@
package notebook
import (
"os"
"path/filepath"
"testing"
"github.com/databricks/databricks-sdk-go/service/workspace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDetectJupyter(t *testing.T) {
var nb bool
var lang workspace.Language
var err error
nb, lang, err = DetectJupyter("./testdata/py_ipynb.ipynb")
require.NoError(t, err)
assert.True(t, nb)
assert.Equal(t, workspace.LanguagePython, lang)
nb, lang, err = DetectJupyter("./testdata/r_ipynb.ipynb")
require.NoError(t, err)
assert.True(t, nb)
assert.Equal(t, workspace.LanguageR, lang)
nb, lang, err = DetectJupyter("./testdata/scala_ipynb.ipynb")
require.NoError(t, err)
assert.True(t, nb)
assert.Equal(t, workspace.LanguageScala, lang)
nb, lang, err = DetectJupyter("./testdata/sql_ipynb.ipynb")
require.NoError(t, err)
assert.True(t, nb)
assert.Equal(t, workspace.LanguageSql, lang)
}
func TestDetectJupyterInvalidJSON(t *testing.T) {
// Create garbage file.
dir := t.TempDir()
path := filepath.Join(dir, "file.ipynb")
buf := make([]byte, 128)
err := os.WriteFile(path, buf, 0644)
require.NoError(t, err)
// Garbage contents means not a notebook.
nb, _, err := DetectJupyter(path)
require.ErrorContains(t, err, "error loading Jupyter notebook file")
assert.False(t, nb)
}
func TestDetectJupyterNoCells(t *testing.T) {
// Create empty JSON file.
dir := t.TempDir()
path := filepath.Join(dir, "file.ipynb")
buf := []byte("{}")
err := os.WriteFile(path, buf, 0644)
require.NoError(t, err)
// Garbage contents means not a notebook.
nb, _, err := DetectJupyter(path)
require.ErrorContains(t, err, "invalid Jupyter notebook file")
assert.False(t, nb)
}
func TestDetectJupyterOldVersion(t *testing.T) {
// Create empty JSON file.
dir := t.TempDir()
path := filepath.Join(dir, "file.ipynb")
buf := []byte(`{ "cells": [], "metadata": {}, "nbformat": 3 }`)
err := os.WriteFile(path, buf, 0644)
require.NoError(t, err)
// Garbage contents means not a notebook.
nb, _, err := DetectJupyter(path)
require.ErrorContains(t, err, "unsupported Jupyter notebook version")
assert.False(t, nb)
}

View File

@ -10,27 +10,27 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestDetect(t *testing.T) { func TestDetectSource(t *testing.T) {
var nb bool var nb bool
var lang workspace.Language var lang workspace.Language
var err error var err error
nb, lang, err = Detect("./testdata/py.py") nb, lang, err = Detect("./testdata/py_source.py")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, nb) assert.True(t, nb)
assert.Equal(t, workspace.LanguagePython, lang) assert.Equal(t, workspace.LanguagePython, lang)
nb, lang, err = Detect("./testdata/r.r") nb, lang, err = Detect("./testdata/r_source.r")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, nb) assert.True(t, nb)
assert.Equal(t, workspace.LanguageR, lang) assert.Equal(t, workspace.LanguageR, lang)
nb, lang, err = Detect("./testdata/scala.scala") nb, lang, err = Detect("./testdata/scala_source.scala")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, nb) assert.True(t, nb)
assert.Equal(t, workspace.LanguageScala, lang) assert.Equal(t, workspace.LanguageScala, lang)
nb, lang, err = Detect("./testdata/sql.sql") nb, lang, err = Detect("./testdata/sql_source.sql")
require.NoError(t, err) require.NoError(t, err)
assert.True(t, nb) assert.True(t, nb)
assert.Equal(t, workspace.LanguageSql, lang) assert.Equal(t, workspace.LanguageSql, lang)
@ -41,6 +41,13 @@ func TestDetect(t *testing.T) {
assert.Equal(t, workspace.Language(""), lang) assert.Equal(t, workspace.Language(""), lang)
} }
func TestDetectCallsDetectJupyter(t *testing.T) {
nb, lang, err := Detect("./testdata/py_ipynb.ipynb")
require.NoError(t, err)
assert.True(t, nb)
assert.Equal(t, workspace.LanguagePython, lang)
}
func TestDetectUnknownExtension(t *testing.T) { func TestDetectUnknownExtension(t *testing.T) {
nb, _, err := Detect("./testdata/doesntexist.foobar") nb, _, err := Detect("./testdata/doesntexist.foobar")
require.NoError(t, err) require.NoError(t, err)