# ---------------------------------------------------------------------------
# Added to experimental/dbx2dab/dbx2dab/compare.py
# ---------------------------------------------------------------------------


class Walker:
    """Depth-first visitor over nested dicts and lists.

    Every value that is neither a ``dict`` nor a ``list`` is treated as a
    scalar and handed to ``callback(path, value)``; the callback's return
    value replaces the scalar. Containers are updated in place and returned,
    so callers that need the input preserved must pass a copy.
    """

    # Class-level default, kept so the attribute exists even on instances
    # created without running ``__init__``.
    _callback = None

    def __init__(self, callback=None):
        """Store ``callback``; ``None`` means scalars are left untouched."""
        self._callback = callback

    def walk(self, obj, path=None):
        """Visit ``obj`` and return it with every scalar replaced.

        ``path`` is the list of dict keys / list indices leading to ``obj``;
        it defaults to an empty list for the root.
        """
        if path is None:
            path = []

        if isinstance(obj, dict):
            return self._walk_dict(obj, path)
        if isinstance(obj, list):
            return self._walk_list(obj, path)
        return self._walk_scalar(obj, path)

    def _walk_dict(self, obj, path):
        """Recurse into each value of ``obj``, writing results back by key."""
        # Only existing keys are reassigned, so the dict's size never changes
        # during iteration.
        for key in obj:
            obj[key] = self.walk(obj[key], path + [key])
        return obj

    def _walk_list(self, obj, path):
        """Recurse into each element of ``obj``, writing results back by index."""
        for i, item in enumerate(obj):
            obj[i] = self.walk(item, path + [i])
        return obj

    def _walk_scalar(self, obj, path):
        """Return ``callback(path, obj)``, or ``obj`` when no callback is set."""
        if self._callback is None:
            return obj
        return self._callback(path, obj)


def walk(obj, callback=None):
    """Convenience wrapper: run a one-off ``Walker(callback)`` over ``obj``."""
    return Walker(callback).walk(obj)


# ---------------------------------------------------------------------------
# Added to experimental/dbx2dab/main.py
# ---------------------------------------------------------------------------


class LookupRewriter:
    """Replace ``<prefix>://<name>`` strings in a job's configs with variables.

    Usage: create one instance per job, call :meth:`add` for every
    environment, then call :meth:`rewrite` once.
    """

    @dataclasses.dataclass
    class RewriteType:
        """How a reference prefix maps onto a lookup variable."""

        variable_name_suffix: str
        object_type: str

    # Recognised reference prefixes. A value of ``None`` marks a prefix that
    # is detected by ``add`` but has no rewrite defined; ``rewrite`` raises
    # for it.
    _prefixes = {
        "cluster://": RewriteType(
            variable_name_suffix="cluster_id",
            object_type="cluster",
        ),
        "cluster-policy://": RewriteType(
            variable_name_suffix="cluster_policy_id",
            object_type="cluster_policy",
        ),
        "instance-profile://": None,
        "instance-pool://": RewriteType(
            variable_name_suffix="instance_pool_id",
            object_type="instance_pool",
        ),
        "pipeline://": None,
        "service-principal://": RewriteType(
            variable_name_suffix="service_principal_id",
            object_type="service_principal",
        ),
        "warehouse://": RewriteType(
            variable_name_suffix="warehouse_id",
            object_type="warehouse",
        ),
        "query://": None,
        "dashboard://": None,
        "alert://": None,
    }

    def __init__(self, job: "Job") -> None:
        """
        One instance per job.
        We track all references by the env they appear in so we can differentiate between them if needed.
        """
        self.job = job
        # env name -> {str(path): [prefix, payload]}
        self.variables = {}

    def add(self, env: str):
        """Record every prefixed reference found in ``self.job.configs[env]``.

        Results are stored in ``self.variables[env]`` keyed by ``str(path)``.
        The same reference may appear at several paths, but one prefix must
        always carry the same payload within an environment because all of
        its occurrences are rewritten to a single variable.

        Raises:
            ValueError: if one prefix is used with two different payloads.
        """
        found = {}
        payload_by_prefix = {}

        def cb(path, obj):
            if not isinstance(obj, str):
                return obj
            for prefix in self._prefixes:
                if not obj.startswith(prefix):
                    continue
                # Strip only the leading prefix; the payload itself may
                # legitimately contain the same text again.
                payload = obj[len(prefix):]
                previous = payload_by_prefix.setdefault(prefix, payload)
                if previous != payload:
                    raise ValueError(
                        f"Conflicting variable references for {prefix} in {env}: "
                        f"{previous!r} and {payload!r}"
                    )
                found[str(path)] = [prefix, payload]
                break
            # The config is only inspected here, never modified.
            return obj

        self.variables[env] = found
        walk(self.job.configs[env], cb)

    def confirm_envs_are_idential(self) -> Dict[str, list]:
        """Return the reference map shared by all added environments.

        Returns an empty dict when no environment has been added.

        Raises:
            ValueError: if any environment's references differ from the
                first one's (paths, prefixes or payloads).
        """
        if not self.variables:
            return {}

        per_env = iter(self.variables.values())
        first = next(per_env)
        for other in per_env:
            # Plain (in)equality is a symmetric deep comparison, so a
            # reference missing on either side is detected.
            if other != first:
                raise ValueError("Variable references differ between environments")
        return first

    def rewrite(self) -> list:
        """Rewrite references in every env's config and describe the variables.

        Each entry of ``self.job.configs`` is replaced by a deep copy in which
        every recorded reference is substituted with ``${var.<name>}``, where
        ``<name>`` is ``self.job.normalized_key()`` joined to the prefix's
        variable suffix with an underscore.

        Returns a list with one dict per distinct variable, in first-seen
        order, of the form:

            {
                "name": "etl_cluster_policy_id",
                "lookup_type": "cluster_policy",
                "lookup_value": "some_policy",
            }

        Raises:
            ValueError: if a reference uses a prefix with no rewrite defined,
                or if the environments disagree on their references.
        """
        rewrites = {}
        variables = {}

        # Compile the variables and how to rewrite the existing instances.
        for path, (prefix, payload) in self.confirm_envs_are_idential().items():
            rewrite_type = self._prefixes[prefix]
            if rewrite_type is None:
                raise ValueError(f"Unhandled prefix: {prefix}")

            variable_name = (
                f"{self.job.normalized_key()}_{rewrite_type.variable_name_suffix}"
            )

            # Add rewrite for the path
            rewrites[path] = f"${{var.{variable_name}}}"

            # Several paths may share one variable; define it only once.
            variables.setdefault(
                variable_name,
                {
                    "name": variable_name,
                    "lookup_type": rewrite_type.object_type,
                    "lookup_value": payload,
                },
            )

        # Now rewrite the job configuration
        def cb(path, obj):
            return rewrites.get(str(path), obj)

        for env in self.job.configs:
            # walk() mutates its argument, so operate on a deep copy.
            self.job.configs[env] = walk(copy.deepcopy(self.job.configs[env]), cb)

        return list(variables.values())
Write resource overrides for env in envs: diff --git a/experimental/dbx2dab/templates/databricks.yml.j2 b/experimental/dbx2dab/templates/databricks.yml.j2 index 9d9ae08f..296294a8 100644 --- a/experimental/dbx2dab/templates/databricks.yml.j2 +++ b/experimental/dbx2dab/templates/databricks.yml.j2 @@ -29,3 +29,12 @@ variables: {{ item }}: description: "" {% endfor %} + + # Variables for fixtures in the workspace that are resolved by name. + # The lookup value is defined below, but can be overridden in the target. + {% for obj in var_lookup_variables -%} + {{ obj["name"] }}: + description: "" + lookup: + {{ obj["lookup_type"] }}: {{ obj["lookup_value"] }} + {% endfor %}