diff --git a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl
index 04bb261cd..42164dff0 100644
--- a/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl
+++ b/libs/template/templates/default-python/template/{{.project_name}}/scratch/exploration.ipynb.tmpl
@@ -1,5 +1,15 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -22,7 +32,7 @@
     "sys.path.append('../src')\n",
     "from {{.project_name}} import main\n",
     "\n",
-    "main.get_taxis().show(10)"
+    "main.get_taxis(spark).show(10)"
 {{else}}
     "spark.range(10)"
 {{end -}}
diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl
index 4f50294f6..b152e9a30 100644
--- a/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl
+++ b/libs/template/templates/default-python/template/{{.project_name}}/src/dlt_pipeline.ipynb.tmpl
@@ -63,7 +63,7 @@
 {{- if (eq .include_python "yes") }}
    "@dlt.view\n",
    "def taxi_raw():\n",
-    "  return main.get_taxis()\n",
+    "  return main.get_taxis(spark)\n",
 {{else}}
    "\n",
    "@dlt.view\n",
diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl
index 0ab61db2c..a228f8d18 100644
--- a/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl
+++ b/libs/template/templates/default-python/template/{{.project_name}}/src/notebook.ipynb.tmpl
@@ -17,6 +17,16 @@
     "This default notebook is executed using Databricks Workflows as defined in resources/{{.project_name}}_job.yml."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 0,
@@ -37,7 +47,7 @@
 {{- if (eq .include_python "yes") }}
     "from {{.project_name}} import main\n",
     "\n",
-    "main.get_taxis().show(10)"
+    "main.get_taxis(spark).show(10)"
 {{else}}
     "spark.range(10)"
 {{end -}}
diff --git a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl
index 4fe5ac8f4..48529e974 100644
--- a/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl
+++ b/libs/template/templates/default-python/template/{{.project_name}}/src/{{.project_name}}/main.py.tmpl
@@ -1,16 +1,12 @@
-{{- /*
-We use pyspark.sql rather than DatabricksSession.builder.getOrCreate()
-for compatibility with older runtimes. With a new runtime, it's
-equivalent to DatabricksSession.builder.getOrCreate().
-*/ -}}
 from pyspark.sql import SparkSession
 
-def get_taxis():
-    spark = SparkSession.builder.getOrCreate()
+def get_taxis(spark: SparkSession):
     return spark.read.table("samples.nyctaxi.trips")
 
 def main():
-    get_taxis().show(5)
+    from databricks.connect import DatabricksSession as SparkSession
+    spark = SparkSession.builder.getOrCreate()
+    get_taxis(spark).show(5)
 
 if __name__ == '__main__':
     main()
diff --git a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl
index a7a6afe0a..8ae043a65 100644
--- a/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl
+++ b/libs/template/templates/default-python/template/{{.project_name}}/tests/main_test.py.tmpl
@@ -1,21 +1,15 @@
-from databricks.connect import DatabricksSession
-from pyspark.sql import SparkSession
+from databricks.connect import DatabricksSession as SparkSession
+from pytest import fixture
 from {{.project_name}} import main
 
-# Create a new Databricks Connect session. If this fails,
-# check that you have configured Databricks Connect correctly.
-# See https://docs.databricks.com/dev-tools/databricks-connect.html.
-{{/*
-  The below works around a problematic error message from Databricks Connect.
-  The standard SparkSession is supported in all configurations (workspace, IDE,
-  all runtime versions, CLI). But on the CLI it currently gives a confusing
-  error message if SPARK_REMOTE is not set. We can't directly use
-  DatabricksSession.builder in main.py, so we're re-assigning it here so
-  everything works out of the box, even for CLI users who don't set SPARK_REMOTE.
-*/}}
-SparkSession.builder = DatabricksSession.builder
-SparkSession.builder.getOrCreate()
 
-def test_main():
-    taxis = main.get_taxis()
+@fixture(scope="session")
+def spark():
+    spark = SparkSession.builder.getOrCreate()
+    yield spark
+    spark.stop()
+
+
+def test_main(spark: SparkSession):
+    taxis = main.get_taxis(spark)
     assert taxis.count() > 5
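For reference, a minimal sketch of how downstream code calls get_taxis() after this change, where the Spark session is injected by the caller rather than created inside the function. This assumes a configured Databricks Connect environment; "my_project" is a hypothetical placeholder for the rendered {{.project_name}} package.

# Sketch only: the new calling convention after this diff.
from databricks.connect import DatabricksSession

from my_project import main  # placeholder for the rendered {{.project_name}} package

# Locally, Databricks Connect builds the session; inside a Databricks
# notebook, the runtime-provided `spark` can be passed directly instead.
spark = DatabricksSession.builder.getOrCreate()
main.get_taxis(spark).show(5)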