mirror of https://github.com/databricks/cli.git
add test for sql as well as enum for extensions
This commit is contained in:
parent
cd8cc2c531
commit
20a30f7c6f
|
@ -71,7 +71,7 @@ func (n *downloader) markNotebookForDownload(ctx context.Context, notebookPath *
|
|||
|
||||
ext := notebook.GetExtensionByLanguage(info)
|
||||
|
||||
filename := path.Base(*notebookPath) + ext
|
||||
filename := path.Base(*notebookPath) + string(ext)
|
||||
targetPath := filepath.Join(n.sourceDir, filename)
|
||||
|
||||
n.files[targetPath] = *notebookPath
|
||||
|
|
|
@ -48,7 +48,7 @@ func (opts exportDirOptions) callback(ctx context.Context, workspaceFiler filer.
|
|||
return err
|
||||
}
|
||||
objectInfo := info.Sys().(workspace.ObjectInfo)
|
||||
targetPath += notebook.GetExtensionByLanguage(&objectInfo)
|
||||
targetPath += string(notebook.GetExtensionByLanguage(&objectInfo))
|
||||
|
||||
// Skip file if a file already exists in path.
|
||||
// os.Stat returns a fs.ErrNotExist if a file does not exist at path.
|
||||
|
|
|
@ -427,8 +427,17 @@ func TestAccFilerWorkspaceNotebookConflict(t *testing.T) {
|
|||
expected1: "// Databricks notebook source\nprintln(1)",
|
||||
content2: readFile(t, "testdata/notebooks/scala2.ipynb"),
|
||||
},
|
||||
{
|
||||
name: "sqlJupyterNotebook.ipynb",
|
||||
nameWithoutExt: "sqlJupyterNotebook",
|
||||
content1: readFile(t, "testdata/notebooks/sql1.ipynb"),
|
||||
expected1: "-- Databricks notebook source\nselect 1",
|
||||
content2: readFile(t, "testdata/notebooks/sql2.ipynb"),
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Upload the notebook
|
||||
err = f.Write(ctx, tc.name, strings.NewReader(tc.content1))
|
||||
require.NoError(t, err)
|
||||
|
@ -516,8 +525,18 @@ func TestAccFilerWorkspaceNotebookWithOverwriteFlag(t *testing.T) {
|
|||
content2: readFile(t, "testdata/notebooks/scala2.ipynb"),
|
||||
expected2: "// Databricks notebook source\nprintln(2)",
|
||||
},
|
||||
{
|
||||
name: "sqlJupyterNotebook.ipynb",
|
||||
nameWithoutExt: "sqlJupyterNotebook",
|
||||
content1: readFile(t, "testdata/notebooks/sql1.ipynb"),
|
||||
expected1: "-- Databricks notebook source\nselect 1",
|
||||
content2: readFile(t, "testdata/notebooks/sql2.ipynb"),
|
||||
expected2: "-- Databricks notebook source\nselect 2",
|
||||
},
|
||||
} {
|
||||
t.Run(tcases.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Upload the notebook
|
||||
err = f.Write(ctx, tcases.name, strings.NewReader(tcases.content1))
|
||||
require.NoError(t, err)
|
||||
|
@ -535,8 +554,6 @@ func TestAccFilerWorkspaceNotebookWithOverwriteFlag(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Add a test that the exported file has the right extension / language_info set?
|
||||
// Required for DABs in the workspace.
|
||||
func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -557,6 +574,7 @@ func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
|
|||
{"rNb.r", "# Databricks notebook source\nprint('first upload'))"},
|
||||
{"scala1.ipynb", readFile(t, "testdata/notebooks/scala1.ipynb")},
|
||||
{"scalaNb.scala", "// Databricks notebook source\n println(\"first upload\"))"},
|
||||
{"sql1.ipynb", readFile(t, "testdata/notebooks/sql1.ipynb")},
|
||||
{"sqlNb.sql", "-- Databricks notebook source\n SELECT \"first upload\""},
|
||||
}
|
||||
|
||||
|
@ -598,6 +616,7 @@ func TestAccFilerWorkspaceFilesExtensionsReadDir(t *testing.T) {
|
|||
"rNb.r",
|
||||
"scala1.ipynb",
|
||||
"scalaNb.scala",
|
||||
"sql1.ipynb",
|
||||
"sqlNb.sql",
|
||||
}, names)
|
||||
|
||||
|
@ -624,6 +643,7 @@ func setupFilerWithExtensionsTest(t *testing.T) filer.Filer {
|
|||
{"p1.ipynb", readFile(t, "testdata/notebooks/py1.ipynb")},
|
||||
{"r1.ipynb", readFile(t, "testdata/notebooks/r1.ipynb")},
|
||||
{"scala1.ipynb", readFile(t, "testdata/notebooks/scala1.ipynb")},
|
||||
{"sql1.ipynb", readFile(t, "testdata/notebooks/sql1.ipynb")},
|
||||
{"pretender", "not a notebook"},
|
||||
{"dir/file.txt", "file content"},
|
||||
{"scala-notebook.scala", "// Databricks notebook source\nprintln('first upload')"},
|
||||
|
@ -649,13 +669,15 @@ func TestAccFilerWorkspaceFilesExtensionsRead(t *testing.T) {
|
|||
// Read contents of test fixtures as a sanity check.
|
||||
filerTest{t, wf}.assertContents(ctx, "foo.py", "# Databricks notebook source\nprint('first upload'))")
|
||||
filerTest{t, wf}.assertContents(ctx, "bar.py", "print('foo')")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "p1.ipynb", "python")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "r1.ipynb", "R")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "scala1.ipynb", "scala")
|
||||
filerTest{t, wf}.assertContents(ctx, "dir/file.txt", "file content")
|
||||
filerTest{t, wf}.assertContents(ctx, "scala-notebook.scala", "// Databricks notebook source\nprintln('first upload')")
|
||||
filerTest{t, wf}.assertContents(ctx, "pretender", "not a notebook")
|
||||
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "p1.ipynb", "python")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "r1.ipynb", "R")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "scala1.ipynb", "scala")
|
||||
filerTest{t, wf}.assertContentsJupyter(ctx, "sql1.ipynb", "sql")
|
||||
|
||||
// Read non-existent file
|
||||
_, err := wf.Read(ctx, "non-existent.py")
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
|
@ -706,6 +728,11 @@ func TestAccFilerWorkspaceFilesExtensionsDelete(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
filerTest{t, wf}.assertNotExists(ctx, "scala1.ipynb")
|
||||
|
||||
// Delete sql jupyter notebook
|
||||
err = wf.Delete(ctx, "sql1.ipynb")
|
||||
require.NoError(t, err)
|
||||
filerTest{t, wf}.assertNotExists(ctx, "sql1.ipynb")
|
||||
|
||||
// Delete non-existent file
|
||||
err = wf.Delete(ctx, "non-existent.py")
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
|
@ -734,56 +761,45 @@ func TestAccFilerWorkspaceFilesExtensionsStat(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
wf := setupFilerWithExtensionsTest(t)
|
||||
|
||||
// Stat on a notebook
|
||||
info, err := wf.Stat(ctx, "foo.py")
|
||||
for _, fileName := range []string{
|
||||
// notebook
|
||||
"foo.py",
|
||||
// file
|
||||
"bar.py",
|
||||
// python jupyter notebook
|
||||
"p1.ipynb",
|
||||
// R jupyter notebook
|
||||
"r1.ipynb",
|
||||
// Scala jupyter notebook
|
||||
"scala1.ipynb",
|
||||
// SQL jupyter notebook
|
||||
"sql1.ipynb",
|
||||
} {
|
||||
info, err := wf.Stat(ctx, fileName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "foo.py", info.Name())
|
||||
assert.False(t, info.IsDir())
|
||||
|
||||
// Stat on a file
|
||||
info, err = wf.Stat(ctx, "bar.py")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bar.py", info.Name())
|
||||
assert.False(t, info.IsDir())
|
||||
|
||||
// Stat on a python jupyter notebook
|
||||
info, err = wf.Stat(ctx, "p1.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "p1.ipynb", info.Name())
|
||||
assert.False(t, info.IsDir())
|
||||
|
||||
// Stat on an R jupyter notebook
|
||||
info, err = wf.Stat(ctx, "r1.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "r1.ipynb", info.Name())
|
||||
assert.False(t, info.IsDir())
|
||||
|
||||
// Stat on a Scala jupyter notebook
|
||||
info, err = wf.Stat(ctx, "scala1.ipynb")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "scala1.ipynb", info.Name())
|
||||
assert.Equal(t, fileName, info.Name())
|
||||
assert.False(t, info.IsDir())
|
||||
}
|
||||
|
||||
// Stat on a directory
|
||||
info, err = wf.Stat(ctx, "dir")
|
||||
info, err := wf.Stat(ctx, "dir")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "dir", info.Name())
|
||||
assert.True(t, info.IsDir())
|
||||
|
||||
// Stat on a non-existent file
|
||||
_, err = wf.Stat(ctx, "non-existent.py")
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
|
||||
// Ensure we do not stat a file as a notebook
|
||||
_, err = wf.Stat(ctx, "pretender.py")
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
|
||||
// Ensure we do not stat a Scala notebook as a Python notebook
|
||||
_, err = wf.Stat(ctx, "scala-notebook.py")
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
|
||||
_, err = wf.Stat(ctx, "pretender.ipynb")
|
||||
for _, fileName := range []string{
|
||||
// non-existent file
|
||||
"non-existent.py",
|
||||
// do not stat a file as a notebook
|
||||
"pretender.py",
|
||||
// do not stat a Scala notebook as a Python notebook
|
||||
"scala-notebook.py",
|
||||
// do not read a regular file as a Jupyter notebook
|
||||
"pretender.ipynb",
|
||||
} {
|
||||
_, err := wf.Stat(ctx, fileName)
|
||||
assert.ErrorIs(t, err, fs.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccWorkspaceFilesExtensionsDirectoriesAreNotNotebooks(t *testing.T) {
|
||||
|
@ -830,8 +846,16 @@ func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
|
|||
sourceContent: "// Databricks notebook source\nprintln('foo')",
|
||||
jupyterName: "foo.ipynb",
|
||||
},
|
||||
{
|
||||
language: "sql",
|
||||
sourceName: "foo.sql",
|
||||
sourceContent: "-- Databricks notebook source\nselect 'foo'",
|
||||
jupyterName: "foo.ipynb",
|
||||
},
|
||||
} {
|
||||
t.Run("source_"+tc.language, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
wf, _ := setupWsfsExtensionsFiler(t)
|
||||
|
||||
|
@ -858,24 +882,32 @@ func TestAccWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
language: "python",
|
||||
sourceName: "bar.py",
|
||||
sourceName: "foo.py",
|
||||
jupyterName: "foo.ipynb",
|
||||
jupyterContent: readFile(t, "testdata/notebooks/py1.ipynb"),
|
||||
},
|
||||
{
|
||||
language: "R",
|
||||
sourceName: "bar.r",
|
||||
sourceName: "foo.r",
|
||||
jupyterName: "foo.ipynb",
|
||||
jupyterContent: readFile(t, "testdata/notebooks/r1.ipynb"),
|
||||
},
|
||||
{
|
||||
language: "scala",
|
||||
sourceName: "bar.scala",
|
||||
sourceName: "foo.scala",
|
||||
jupyterName: "foo.ipynb",
|
||||
jupyterContent: readFile(t, "testdata/notebooks/scala1.ipynb"),
|
||||
},
|
||||
{
|
||||
language: "sql",
|
||||
sourceName: "foo.sql",
|
||||
jupyterName: "foo.ipynb",
|
||||
jupyterContent: readFile(t, "testdata/notebooks/sql1.ipynb"),
|
||||
},
|
||||
} {
|
||||
t.Run("jupyter_"+tc.language, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
wf, _ := setupWsfsExtensionsFiler(t)
|
||||
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"select 1"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "sql"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"select 2"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "sql"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"io/fs"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/databricks/cli/libs/log"
|
||||
|
@ -23,17 +24,6 @@ type workspaceFilesExtensionsClient struct {
|
|||
readonly bool
|
||||
}
|
||||
|
||||
var extensionsToLanguages = map[string]workspace.Language{
|
||||
".py": workspace.LanguagePython,
|
||||
".r": workspace.LanguageR,
|
||||
".scala": workspace.LanguageScala,
|
||||
".sql": workspace.LanguageSql,
|
||||
|
||||
// The platform supports all languages (Python, R, Scala, and SQL) for Jupyter notebooks.
|
||||
// Thus, we do not need to check the language for .ipynb files.
|
||||
".ipynb": workspace.LanguagePython,
|
||||
}
|
||||
|
||||
type workspaceFileStatus struct {
|
||||
wsfsFileInfo
|
||||
|
||||
|
@ -57,11 +47,12 @@ func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string)
|
|||
// This function returns the stat for the provided notebook. The stat object itself contains the path
|
||||
// with the extension since it is meant to be used in the context of a fs.FileInfo.
|
||||
func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) {
|
||||
ext := path.Ext(name)
|
||||
nameWithoutExt := strings.TrimSuffix(name, ext)
|
||||
// TODO: What happens when this type casting is not possible?
|
||||
ext := notebook.Extension(path.Ext(name))
|
||||
nameWithoutExt := strings.TrimSuffix(name, string(ext))
|
||||
|
||||
// File name does not have an extension associated with Databricks notebooks, return early.
|
||||
if _, ok := extensionsToLanguages[ext]; !ok {
|
||||
if _, ok := notebook.ExtensionToLanguage[ext]; !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -84,28 +75,33 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx contex
|
|||
|
||||
// Not the correct language. Return early. Note: All languages are supported
|
||||
// for Jupyter notebooks.
|
||||
if ext != ".ipynb" && stat.Language != extensionsToLanguages[ext] {
|
||||
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), extensionsToLanguages[ext], stat.Language)
|
||||
if ext != notebook.ExtensionJupyter && stat.Language != notebook.ExtensionToLanguage[ext] {
|
||||
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), notebook.ExtensionToLanguage[ext], stat.Language)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// When the extension is .py we expect the export format to be source.
|
||||
// When the extension is one of .py, .r, .scala or .sql we expect the export format to be source.
|
||||
// If it's not, return early.
|
||||
if ext == ".py" && stat.ReposExportFormat != workspace.ExportFormatSource {
|
||||
if slices.Contains([]notebook.Extension{
|
||||
notebook.ExtensionPython,
|
||||
notebook.ExtensionR,
|
||||
notebook.ExtensionScala,
|
||||
notebook.ExtensionSql}, ext) &&
|
||||
stat.ReposExportFormat != workspace.ExportFormatSource {
|
||||
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a source notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// When the extension is .ipynb we expect the export format to be Jupyter.
|
||||
// If it's not, return early.
|
||||
if ext == ".ipynb" && stat.ReposExportFormat != workspace.ExportFormatJupyter {
|
||||
if ext == notebook.ExtensionJupyter && stat.ReposExportFormat != workspace.ExportFormatJupyter {
|
||||
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a Jupyter notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Modify the stat object path to include the extension. This stat object will be used
|
||||
// to return the fs.FileInfo object in the stat method.
|
||||
stat.Path = stat.Path + ext
|
||||
stat.Path = stat.Path + string(ext)
|
||||
return &workspaceFileStatus{
|
||||
wsfsFileInfo: stat,
|
||||
nameForWorkspaceAPI: nameWithoutExt,
|
||||
|
@ -130,12 +126,12 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx con
|
|||
// If the notebook was exported as a Jupyter notebook, the extension should be .ipynb.
|
||||
// TODO: Test this.
|
||||
if stat.ReposExportFormat == workspace.ExportFormatJupyter {
|
||||
ext = ".ipynb"
|
||||
ext = notebook.ExtensionJupyter
|
||||
}
|
||||
|
||||
// Modify the stat object path to include the extension. This stat object will be used
|
||||
// to return the fs.DirEntry object in the ReadDir method.
|
||||
stat.Path = stat.Path + ext
|
||||
stat.Path = stat.Path + string(ext)
|
||||
return &workspaceFileStatus{
|
||||
wsfsFileInfo: stat,
|
||||
nameForWorkspaceAPI: name,
|
||||
|
|
|
@ -2,22 +2,43 @@ package notebook
|
|||
|
||||
import "github.com/databricks/databricks-sdk-go/service/workspace"
|
||||
|
||||
func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) string {
|
||||
type Extension string
|
||||
|
||||
const (
|
||||
ExtensionNone Extension = ""
|
||||
ExtensionPython Extension = ".py"
|
||||
ExtensionR Extension = ".r"
|
||||
ExtensionScala Extension = ".scala"
|
||||
ExtensionSql Extension = ".sql"
|
||||
ExtensionJupyter Extension = ".ipynb"
|
||||
)
|
||||
|
||||
var ExtensionToLanguage = map[Extension]workspace.Language{
|
||||
ExtensionPython: workspace.LanguagePython,
|
||||
ExtensionR: workspace.LanguageR,
|
||||
ExtensionScala: workspace.LanguageScala,
|
||||
ExtensionSql: workspace.LanguageSql,
|
||||
|
||||
// The platform supports all languages (Python, R, Scala, and SQL) for Jupyter notebooks.
|
||||
ExtensionJupyter: workspace.LanguageUnknown,
|
||||
}
|
||||
|
||||
func GetExtensionByLanguage(objectInfo *workspace.ObjectInfo) Extension {
|
||||
if objectInfo.ObjectType != workspace.ObjectTypeNotebook {
|
||||
return ""
|
||||
return ExtensionNone
|
||||
}
|
||||
|
||||
switch objectInfo.Language {
|
||||
case workspace.LanguagePython:
|
||||
return ".py"
|
||||
return ExtensionPython
|
||||
case workspace.LanguageR:
|
||||
return ".r"
|
||||
return ExtensionR
|
||||
case workspace.LanguageScala:
|
||||
return ".scala"
|
||||
return ExtensionScala
|
||||
case workspace.LanguageSql:
|
||||
return ".sql"
|
||||
return ExtensionSql
|
||||
default:
|
||||
// Do not add any extension to the file name
|
||||
return ""
|
||||
return ExtensionNone
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue