Fix DBConnect support in VS Code (#1253)

## Changes

With the current template, we can't execute the Python file and the jobs
notebook using DBConnect from VSCode because we import `from pyspark.sql
import SparkSession`, which doesn't support Databricks unified auth.
This PR fixes this by passing spark into the library code and by
explicitly instantiating a spark session where the spark global is not
available.

Other changes:

* add auto-reload to notebooks
* add DLT typings for code completion
This commit is contained in:
Fabian Jakobs 2024-03-05 15:31:27 +01:00 committed by GitHub
parent ecf9c52f61
commit e61f0e1eb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 42 additions and 29 deletions

View File

@ -3,6 +3,9 @@
## For defining dependencies used by jobs in Databricks Workflows, see
## https://docs.databricks.com/dev-tools/bundles/library-dependencies.html
## Add code completion support for DLT
databricks-dlt
## pytest is the default package used for testing
pytest

View File

@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -22,7 +32,7 @@
"sys.path.append('../src')\n",
"from {{.project_name}} import main\n",
"\n",
"main.get_taxis().show(10)"
"main.get_taxis(spark).show(10)"
{{else}}
"spark.range(10)"
{{end -}}

View File

@ -63,7 +63,7 @@
{{- if (eq .include_python "yes") }}
"@dlt.view\n",
"def taxi_raw():\n",
" return main.get_taxis()\n",
" return main.get_taxis(spark)\n",
{{else}}
"\n",
"@dlt.view\n",

View File

@ -17,6 +17,16 @@
"This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 0,
@ -37,7 +47,7 @@
{{- if (eq .include_python "yes") }}
"from {{.project_name}} import main\n",
"\n",
"main.get_taxis().show(10)"
"main.get_taxis(spark).show(10)"
{{else}}
"spark.range(10)"
{{end -}}

View File

@ -1,16 +1,21 @@
{{- /*
We use pyspark.sql rather than DatabricksSession.builder.getOrCreate()
for compatibility with older runtimes. With a new runtime, it's
equivalent to DatabricksSession.builder.getOrCreate().
*/ -}}
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, DataFrame
def get_taxis():
spark = SparkSession.builder.getOrCreate()
def get_taxis(spark: SparkSession) -> DataFrame:
return spark.read.table("samples.nyctaxi.trips")
# Create a new Databricks Connect session. If this fails,
# check that you have configured Databricks Connect correctly.
# See https://docs.databricks.com/dev-tools/databricks-connect.html.
def get_spark() -> SparkSession:
try:
from databricks.connect import DatabricksSession
return DatabricksSession.builder.getOrCreate()
except ImportError:
return SparkSession.builder.getOrCreate()
def main():
get_taxis().show(5)
get_taxis(get_spark()).show(5)
if __name__ == '__main__':
main()

View File

@ -1,21 +1,6 @@
from databricks.connect import DatabricksSession
from pyspark.sql import SparkSession
from {{.project_name}} import main
from {{.project_name}}.main import get_taxis, get_spark
# Create a new Databricks Connect session. If this fails,
# check that you have configured Databricks Connect correctly.
# See https://docs.databricks.com/dev-tools/databricks-connect.html.
{{/*
The below works around a problematic error message from Databricks Connect.
The standard SparkSession is supported in all configurations (workspace, IDE,
all runtime versions, CLI). But on the CLI it currently gives a confusing
error message if SPARK_REMOTE is not set. We can't directly use
DatabricksSession.builder in main.py, so we're re-assigning it here so
everything works out of the box, even for CLI users who don't set SPARK_REMOTE.
*/}}
SparkSession.builder = DatabricksSession.builder
SparkSession.builder.getOrCreate()
def test_main():
taxis = main.get_taxis()
taxis = get_taxis(get_spark())
assert taxis.count() > 5