Add SQL notebook tests and an Extension enum for notebook file extensions

This commit is contained in:
Shreyas Goenka 2024-10-21 21:40:43 +02:00
parent cd8cc2c531
commit 20a30f7c6f
No known key found for this signature in database
GPG Key ID: 92A07DF49CCB0622
7 changed files with 172 additions and 83 deletions

View File

@ -71,7 +71,7 @@ func (n *downloader) markNotebookForDownload(ctx context.Context, notebookPath *
ext := notebook.GetExtensionByLanguage(info)
filename := path.Base(*notebookPath) + ext
filename := path.Base(*notebookPath) + string(ext)
targetPath := filepath.Join(n.sourceDir, filename)
n.files[targetPath] = *notebookPath

View File

@ -48,7 +48,7 @@ func (opts exportDirOptions) callback(ctx context.Context, workspaceFiler filer.
return err
}
objectInfo := info.Sys().(workspace.ObjectInfo)
targetPath += notebook.GetExtensionByLanguage(&objectInfo)
targetPath += string(notebook.GetExtensionByLanguage(&objectInfo))
// Skip file if a file already exists in path.
// os.Stat returns a fs.ErrNotExist if a file does not exist at path.

View File

@ -427,8 +427,17 @@ func TestAccFilerWorkspaceNotebookConflict(t *testing.T) {
expected1: "// Databricks notebook source\nprintln(1)",
content2: readFile(t, "testdata/notebooks/scala2.ipynb"),
},
{
name: "sqlJupyterNotebook.ipynb",
nameWithoutExt: "sqlJupyterNotebook",
content1: readFile(t, "testdata/notebooks/sql1.ipynb"),
expected1: "-- Databricks notebook source\nselect 1",
content2: readFile(t, "testdata/notebooks/sql2.ipynb"),
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Upload the notebook
err = f.Write(ctx, tc.name, strings.NewReader(tc.content1))
require.NoError(t, err)
@ -516,8 +525,18 @@ func TestAccFilerWorkspaceNotebookWithOverwriteFlag(t *testing.T) {
content2: readFile(t, "testdata/notebooks/scala2.ipynb"),
expected2: "// Databricks notebook source\nprintln(2)",
},
{
name: "sqlJupyterNotebook.ipynb",
nameWithoutExt: "sqlJupyterNotebook",
content1: readFile(t, "testdata/notebooks/sql1.ipynb"),
expected1: "-- Databricks notebook source\nselect 1",
content2: readFile(t, "testdata/notebooks/sql2.ipynb"),
expected2: "-- Databricks notebook source\nselect 2",
},
} {
t.Run(tcases.name, func(t *testing.T) {
t.Parallel()
// Upload the notebook
err = f.Write(ctx, tcases.name, strings.NewReader(tcases.content1))
require.NoError(t, err)
@ -535,8 +554,6 @@ func TestAccFilerWorkspaceNotebookWithOverwriteFlag(t *testing.T) {
}
}
// TODO: Add a test that the exported file has the right extension / language_info set?
// Required for DABs in the workspace.
func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
t.Parallel()
@ -557,6 +574,7 @@ func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
{"rNb.r", "# Databricks notebook source\nprint('first upload'))"},
{"scala1.ipynb", readFile(t, "testdata/notebooks/scala1.ipynb")},
{"scalaNb.scala", "// Databricks notebook source\n println(\"first upload\"))"},
{"sql1.ipynb", readFile(t, "testdata/notebooks/sql1.ipynb")},
{"sqlNb.sql", "-- Databricks notebook source\n SELECT \"first upload\""},
}
@ -598,6 +616,7 @@ func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
"rNb.r",
"scala1.ipynb",
"scalaNb.scala",
"sql1.ipynb",
"sqlNb.sql",
}, names)
@ -624,6 +643,7 @@ func setupFilerWithExtensionsTest(t *testing.T) filer.Filer {
{"p1.ipynb", readFile(t, "testdata/notebooks/py1.ipynb")},
{"r1.ipynb", readFile(t, "testdata/notebooks/r1.ipynb")},
{"scala1.ipynb", readFile(t, "testdata/notebooks/scala1.ipynb")},
{"sql1.ipynb", readFile(t, "testdata/notebooks/sql1.ipynb")},
{"pretender", "not a notebook"},
{"dir/file.txt", "file content"},
{"scala-notebook.scala", "// Databricks notebook source\nprintln('first upload')"},
@ -649,13 +669,15 @@ func TestAccFilerWorkspaceFilesExtensionsRead(t *testing.T) {
// Read contents of test fixtures as a sanity check.
filerTest{t, wf}.assertContents(ctx, "foo.py", "# Databricks notebook source\nprint('first upload'))")
filerTest{t, wf}.assertContents(ctx, "bar.py", "print('foo')")
filerTest{t, wf}.assertContentsJupyter(ctx, "p1.ipynb", "python")
filerTest{t, wf}.assertContentsJupyter(ctx, "r1.ipynb", "R")
filerTest{t, wf}.assertContentsJupyter(ctx, "scala1.ipynb", "scala")
filerTest{t, wf}.assertContents(ctx, "dir/file.txt", "file content")
filerTest{t, wf}.assertContents(ctx, "scala-notebook.scala", "// Databricks notebook source\nprintln('first upload')")
filerTest{t, wf}.assertContents(ctx, "pretender", "not a notebook")
filerTest{t, wf}.assertContentsJupyter(ctx, "p1.ipynb", "python")
filerTest{t, wf}.assertContentsJupyter(ctx, "r1.ipynb", "R")
filerTest{t, wf}.assertContentsJupyter(ctx, "scala1.ipynb", "scala")
filerTest{t, wf}.assertContentsJupyter(ctx, "sql1.ipynb", "sql")
// Read non-existent file
_, err := wf.Read(ctx, "non-existent.py")
assert.ErrorIs(t, err, fs.ErrNotExist)
@ -706,6 +728,11 @@ func TestAccFilerWorkspaceFilesExtensionsDelete(t *testing.T) {
require.NoError(t, err)
filerTest{t, wf}.assertNotExists(ctx, "scala1.ipynb")
// Delete sql jupyter notebook
err = wf.Delete(ctx, "sql1.ipynb")
require.NoError(t, err)
filerTest{t, wf}.assertNotExists(ctx, "sql1.ipynb")
// Delete non-existent file
err = wf.Delete(ctx, "non-existent.py")
assert.ErrorIs(t, err, fs.ErrNotExist)
@ -734,56 +761,45 @@ func TestAccFilerWorkspaceFilesExtensionsStat(t *testing.T) {
ctx := context.Background()
wf := setupFilerWithExtensionsTest(t)
// Stat on a notebook
info, err := wf.Stat(ctx, "foo.py")
require.NoError(t, err)
assert.Equal(t, "foo.py", info.Name())
assert.False(t, info.IsDir())
// Stat on a file
info, err = wf.Stat(ctx, "bar.py")
require.NoError(t, err)
assert.Equal(t, "bar.py", info.Name())
assert.False(t, info.IsDir())
// Stat on a python jupyter notebook
info, err = wf.Stat(ctx, "p1.ipynb")
require.NoError(t, err)
assert.Equal(t, "p1.ipynb", info.Name())
assert.False(t, info.IsDir())
// Stat on an R jupyter notebook
info, err = wf.Stat(ctx, "r1.ipynb")
require.NoError(t, err)
assert.Equal(t, "r1.ipynb", info.Name())
assert.False(t, info.IsDir())
// Stat on a Scala jupyter notebook
info, err = wf.Stat(ctx, "scala1.ipynb")
require.NoError(t, err)
assert.Equal(t, "scala1.ipynb", info.Name())
assert.False(t, info.IsDir())
for _, fileName := range []string{
// notebook
"foo.py",
// file
"bar.py",
// python jupyter notebook
"p1.ipynb",
// R jupyter notebook
"r1.ipynb",
// Scala jupyter notebook
"scala1.ipynb",
// SQL jupyter notebook
"sql1.ipynb",
} {
info, err := wf.Stat(ctx, fileName)
require.NoError(t, err)
assert.Equal(t, fileName, info.Name())
assert.False(t, info.IsDir())
}
// Stat on a directory
info, err = wf.Stat(ctx, "dir")
info, err := wf.Stat(ctx, "dir")
require.NoError(t, err)
assert.Equal(t, "dir", info.Name())
assert.True(t, info.IsDir())
// Stat on a non-existent file
_, err = wf.Stat(ctx, "non-existent.py")
assert.ErrorIs(t, err, fs.ErrNotExist)
// Ensure we do not stat a file as a notebook
_, err = wf.Stat(ctx, "pretender.py")
assert.ErrorIs(t, err, fs.ErrNotExist)
// Ensure we do not stat a Scala notebook as a Python notebook
_, err = wf.Stat(ctx, "scala-notebook.py")
assert.ErrorIs(t, err, fs.ErrNotExist)
_, err = wf.Stat(ctx, "pretender.ipynb")
assert.ErrorIs(t, err, fs.ErrNotExist)
for _, fileName := range []string{
// non-existent file
"non-existent.py",
// do not stat a file as a notebook
"pretender.py",
// do not stat a Scala notebook as a Python notebook
"scala-notebook.py",
// do not read a regular file as a Jupyter notebook
"pretender.ipynb",
} {
_, err := wf.Stat(ctx, fileName)
assert.ErrorIs(t, err, fs.ErrNotExist)
}
}
func TestAccWorkspaceFilesExtensionsDirectoriesAreNotNotebooks(t *testing.T) {
@ -830,8 +846,16 @@ func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
sourceContent: "// Databricks notebook source\nprintln('foo')",
jupyterName: "foo.ipynb",
},
{
language: "sql",
sourceName: "foo.sql",
sourceContent: "-- Databricks notebook source\nselect 'foo'",
jupyterName: "foo.ipynb",
},
} {
t.Run("source_"+tc.language, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
wf, _ := setupWsfsExtensionsFiler(t)
@ -858,24 +882,32 @@ func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
}{
{
language: "python",
sourceName: "bar.py",
sourceName: "foo.py",
jupyterName: "foo.ipynb",
jupyterContent: readFile(t, "testdata/notebooks/py1.ipynb"),
},
{
language: "R",
sourceName: "bar.r",
sourceName: "foo.r",
jupyterName: "foo.ipynb",
jupyterContent: readFile(t, "testdata/notebooks/r1.ipynb"),
},
{
language: "scala",
sourceName: "bar.scala",
sourceName: "foo.scala",
jupyterName: "foo.ipynb",
jupyterContent: readFile(t, "testdata/notebooks/scala1.ipynb"),
},
{
language: "sql",
sourceName: "foo.sql",
jupyterName: "foo.ipynb",
jupyterContent: readFile(t, "testdata/notebooks/sql1.ipynb"),
},
} {
t.Run("jupyter_"+tc.language, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
wf, _ := setupWsfsExtensionsFiler(t)

20
internal/testdata/notebooks/sql1.ipynb vendored Normal file
View File

@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"select 1"
]
}
],
"metadata": {
"language_info": {
"name": "sql"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

20
internal/testdata/notebooks/sql2.ipynb vendored Normal file
View File

@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"select 2"
]
}
],
"metadata": {
"language_info": {
"name": "sql"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -7,6 +7,7 @@ import (
"io"
"io/fs"
"path"
"slices"
"strings"
"github.com/databricks/cli/libs/log"
@ -23,17 +24,6 @@ type workspaceFilesExtensionsClient struct {
readonly bool
}
var extensionsToLanguages = map[string]workspace.Language{
".py": workspace.LanguagePython,
".r": workspace.LanguageR,
".scala": workspace.LanguageScala,
".sql": workspace.LanguageSql,
// The platform supports all languages (Python, R, Scala, and SQL) for Jupyter notebooks.
// Thus, we do not need to check the language for .ipynb files.
".ipynb": workspace.LanguagePython,
}
type workspaceFileStatus struct {
wsfsFileInfo
@ -57,11 +47,12 @@ func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string)
// This function returns the stat for the provided notebook. The stat object itself contains the path
// with the extension since it is meant to be used in the context of a fs.FileInfo.
func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) {
ext := path.Ext(name)
nameWithoutExt := strings.TrimSuffix(name, ext)
// Converting a string to notebook.Extension cannot fail. Extensions that are not
// associated with Databricks notebooks are rejected by the map lookup below.
ext := notebook.Extension(path.Ext(name))
nameWithoutExt := strings.TrimSuffix(name, string(ext))
// File name does not have an extension associated with Databricks notebooks, return early.
if _, ok := extensionsToLanguages[ext]; !ok {
if _, ok := notebook.ExtensionToLanguage[ext]; !ok {
return nil, nil
}
@ -84,28 +75,33 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx contex
// Not the correct language. Return early. Note: All languages are supported
// for Jupyter notebooks.
if ext != ".ipynb" && stat.Language != extensionsToLanguages[ext] {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), extensionsToLanguages[ext], stat.Language)
if ext != notebook.ExtensionJupyter && stat.Language != notebook.ExtensionToLanguage[ext] {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), notebook.ExtensionToLanguage[ext], stat.Language)
return nil, nil
}
// When the extension is .py we expect the export format to be source.
// When the extension is one of .py, .r, .scala or .sql we expect the export format to be source.
// If it's not, return early.
if ext == ".py" && stat.ReposExportFormat != workspace.ExportFormatSource {
if slices.Contains([]notebook.Extension{
notebook.ExtensionPython,
notebook.ExtensionR,
notebook.ExtensionScala,
notebook.ExtensionSql}, ext) &&
stat.ReposExportFormat != workspace.ExportFormatSource {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a source notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
return nil, nil
}
// When the extension is .ipynb we expect the export format to be Jupyter.
// If it's not, return early.
if ext == ".ipynb" && stat.ReposExportFormat != workspace.ExportFormatJupyter {
if ext == notebook.ExtensionJupyter && stat.ReposExportFormat != workspace.ExportFormatJupyter {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a Jupyter notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
return nil, nil
}
// Modify the stat object path to include the extension. This stat object will be used
// to return the fs.FileInfo object in the stat method.
stat.Path = stat.Path + ext
stat.Path = stat.Path + string(ext)
return &workspaceFileStatus{
wsfsFileInfo: stat,
nameForWorkspaceAPI: nameWithoutExt,
@ -130,12 +126,12 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx con
// If the notebook was exported as a Jupyter notebook, the extension should be .ipynb.
// TODO: Test this.
if stat.ReposExportFormat == workspace.ExportFormatJupyter {
ext = ".ipynb"
ext = notebook.ExtensionJupyter
}
// Modify the stat object path to include the extension. This stat object will be used
// to return the fs.DirEntry object in the ReadDir method.
stat.Path = stat.Path + ext
stat.Path = stat.Path + string(ext)
return &workspaceFileStatus{
wsfsFileInfo: stat,
nameForWorkspaceAPI: name,

View File

@ -2,22 +2,43 @@ package notebook
import "github.com/databricks/databricks-sdk-go/service/workspace"
func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) string {
// Extension is a notebook file-name extension, including its leading dot
// (for example ".py"). The empty string means "no extension".
type Extension string
// Extensions recognized for notebooks.
const (
// ExtensionNone is the empty extension.
ExtensionNone Extension = ""
// ExtensionPython is the extension of Python source notebooks.
ExtensionPython Extension = ".py"
// ExtensionR is the extension of R source notebooks.
ExtensionR Extension = ".r"
// ExtensionScala is the extension of Scala source notebooks.
ExtensionScala Extension = ".scala"
// ExtensionSql is the extension of SQL source notebooks.
ExtensionSql Extension = ".sql"
// ExtensionJupyter is the extension of Jupyter notebooks.
ExtensionJupyter Extension = ".ipynb"
)
// ExtensionToLanguage maps each notebook extension to the workspace language
// associated with it. ExtensionNone has no entry, so a lookup with the
// comma-ok form reports whether an extension belongs to a notebook format.
var ExtensionToLanguage = map[Extension]workspace.Language{
ExtensionPython: workspace.LanguagePython,
ExtensionR: workspace.LanguageR,
ExtensionScala: workspace.LanguageScala,
ExtensionSql: workspace.LanguageSql,
// The platform supports all languages (Python, R, Scala, and SQL) for Jupyter notebooks,
// so the Jupyter extension is not tied to any single language.
ExtensionJupyter: workspace.LanguageUnknown,
}
// GetExtensionByLanguage returns the source-file extension that matches the
// language recorded in objectInfo.
//
// It returns ExtensionNone (the empty string) when the object is not a
// notebook, or when its language is not one of Python, R, Scala, or SQL.
// It never returns ExtensionJupyter: only objectInfo.ObjectType and
// objectInfo.Language are inspected here.
func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) Extension {
if objectInfo.ObjectType != workspace.ObjectTypeNotebook {
return ""
return ExtensionNone
}
switch objectInfo.Language {
case workspace.LanguagePython:
return ".py"
return ExtensionPython
case workspace.LanguageR:
return ".r"
return ExtensionR
case workspace.LanguageScala:
return ".scala"
return ExtensionScala
case workspace.LanguageSql:
return ".sql"
return ExtensionSql
default:
// Do not add any extension to the file name
return ""
return ExtensionNone
}
}