Fix DBConnect support in VS Code (#1253)

## Changes

With the current template, we can't execute the Python file or the jobs
notebook over DBConnect from VS Code, because the code imports
`SparkSession` from `pyspark.sql`, which doesn't support Databricks
unified auth. This PR fixes that by passing `spark` into the library
code and by explicitly instantiating a Spark session where the `spark`
global is not available.
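
In sketch form, the pattern the template now follows (simplified from the
`main.py` diff below; the table name comes from the default template):

```python
from pyspark.sql import SparkSession, DataFrame

def get_taxis(spark: SparkSession) -> DataFrame:
    # Library code receives the session instead of creating its own.
    return spark.read.table("samples.nyctaxi.trips")

def get_spark() -> SparkSession:
    # Prefer Databricks Connect, which supports unified auth from VS Code;
    # fall back to plain PySpark where databricks-connect isn't installed.
    try:
        from databricks.connect import DatabricksSession
        return DatabricksSession.builder.getOrCreate()
    except ImportError:
        return SparkSession.builder.getOrCreate()
```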

Other changes:

* add auto-reload to notebooks
* add DLT typings for code completion
6 changed files with 42 additions and 29 deletions


@@ -3,6 +3,9 @@
 ## For defining dependencies used by jobs in Databricks Workflows, see
 ## https://docs.databricks.com/dev-tools/bundles/library-dependencies.html
 
+## Add code completion support for DLT
+databricks-dlt
+
 ## pytest is the default package used for testing
 pytest
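
The `databricks-dlt` package provides local stubs for the `dlt` module, so
`import dlt` and the pipeline decorators resolve in the IDE. A minimal sketch
of what completion now covers (the `spark` global is only defined when the
pipeline actually runs on Databricks):

```python
import dlt  # resolves locally via the databricks-dlt stubs
from pyspark.sql import DataFrame

@dlt.view
def taxi_raw() -> DataFrame:
    # `spark` is injected as a global when this runs as a DLT pipeline.
    return spark.read.table("samples.nyctaxi.trips")
```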


@@ -1,5 +1,15 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -22,7 +32,7 @@
    "sys.path.append('../src')\n",
    "from {{.project_name}} import main\n",
    "\n",
-   "main.get_taxis().show(10)"
+   "main.get_taxis(spark).show(10)"
 {{else}}
    "spark.range(10)"
 {{end -}}
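
In the notebook, `spark` is a global supplied by the environment: by the
runtime when run in the workspace, and by the VS Code extension when the cell
executes over DBConnect. A rough stand-alone equivalent (assuming a configured
Databricks Connect profile) is:

```python
from databricks.connect import DatabricksSession

# Picks up Databricks unified auth, e.g. a profile in ~/.databrickscfg.
spark = DatabricksSession.builder.getOrCreate()
```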


@@ -63,7 +63,7 @@
 {{- if (eq .include_python "yes") }}
    "@dlt.view\n",
    "def taxi_raw():\n",
-   "  return main.get_taxis()\n",
+   "  return main.get_taxis(spark)\n",
 {{else}}
    "\n",
    "@dlt.view\n",


@@ -17,6 +17,16 @@
    "This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml."
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": 2,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "%load_ext autoreload\n",
+   "%autoreload 2"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": 0,
@@ -37,7 +47,7 @@
 {{- if (eq .include_python "yes") }}
    "from {{.project_name}} import main\n",
    "\n",
-   "main.get_taxis().show(10)"
+   "main.get_taxis(spark).show(10)"
 {{else}}
    "spark.range(10)"
 {{end -}}


@@ -1,16 +1,21 @@
-{{- /*
-We use pyspark.sql rather than DatabricksSession.builder.getOrCreate()
-for compatibility with older runtimes. With a new runtime, it's
-equivalent to DatabricksSession.builder.getOrCreate().
-*/ -}}
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, DataFrame
 
-def get_taxis():
-  spark = SparkSession.builder.getOrCreate()
+def get_taxis(spark: SparkSession) -> DataFrame:
   return spark.read.table("samples.nyctaxi.trips")
 
+# Create a new Databricks Connect session. If this fails,
+# check that you have configured Databricks Connect correctly.
+# See https://docs.databricks.com/dev-tools/databricks-connect.html.
+def get_spark() -> SparkSession:
+  try:
+    from databricks.connect import DatabricksSession
+    return DatabricksSession.builder.getOrCreate()
+  except ImportError:
+    return SparkSession.builder.getOrCreate()
+
 def main():
-  get_taxis().show(5)
+  get_taxis(get_spark()).show(5)
 
 if __name__ == '__main__':
   main()
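
A quick way to sanity-check the new entry point from a local Python session
(assumes `databricks-connect` is installed and a default profile is
configured; the inline comments are illustrative):

```python
from {{.project_name}}.main import get_spark, get_taxis

spark = get_spark()        # a DatabricksSession if databricks-connect imports
get_taxis(spark).show(5)   # reads samples.nyctaxi.trips over Databricks Connect
```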


@@ -1,21 +1,6 @@
-from databricks.connect import DatabricksSession
-from pyspark.sql import SparkSession
-from {{.project_name}} import main
+from {{.project_name}}.main import get_taxis, get_spark
 
-# Create a new Databricks Connect session. If this fails,
-# check that you have configured Databricks Connect correctly.
-# See https://docs.databricks.com/dev-tools/databricks-connect.html.
-{{/*
-The below works around a problematic error message from Databricks Connect.
-The standard SparkSession is supported in all configurations (workspace, IDE,
-all runtime versions, CLI). But on the CLI it currently gives a confusing
-error message if SPARK_REMOTE is not set. We can't directly use
-DatabricksSession.builder in main.py, so we're re-assigning it here so
-everything works out of the box, even for CLI users who don't set SPARK_REMOTE.
-*/}}
-SparkSession.builder = DatabricksSession.builder
-SparkSession.builder.getOrCreate()
 
 def test_main():
-  taxis = main.get_taxis()
+  taxis = get_taxis(get_spark())
   assert taxis.count() > 5
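
Because the session is now injected, tests can also share one session across
the run via a fixture; a sketch under the same assumptions (the fixture is
illustrative, not part of the template):

```python
import pytest
from pyspark.sql import SparkSession

from {{.project_name}}.main import get_spark, get_taxis

@pytest.fixture(scope="session")
def spark() -> SparkSession:
    # One session for the whole test run; get_spark() prefers
    # Databricks Connect and falls back to plain PySpark.
    return get_spark()

def test_get_taxis(spark: SparkSession):
    assert get_taxis(spark).count() > 5
```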