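"""Generate Databricks Asset Bundle (DAB) configuration from a DBX project.

Summary of what the code below does: read the workflow definitions of a DBX
project for every detected environment, factor out the configuration shared
across environments, rewrite "...://" lookup references into bundle
variables, and write a databricks.yml plus per-job resource definitions and
per-target override and variable files.
"""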
import argparse
import dataclasses
import sys
import re
import copy
from pathlib import Path
from typing import Any, Dict, List

import yaml
import jinja2

from dbx2dab.compare import (
    recursive_compare,
    recursive_intersection,
    recursive_subtract,
    recursive_merge,
    walk,
)

from dbx2dab.loader import Loader


class VerboseSafeDumper(yaml.SafeDumper):
    """
    A YAML dumper that does not emit aliases.

    Repeated objects are written out in full instead of as YAML anchors and
    references (e.g. "&id001" / "*id001"), which keeps the generated files readable.
    """

    def ignore_aliases(self, data):
        return True


class Job:
    def __init__(self, name: str) -> None:
        self.name = name
        self.configs: Dict[str, Any] = {}

    def normalized_key(self) -> str:
        """Derive a stable resource key from the job name."""
        name = str(self.name)
        # Remove interpolation placeholders such as "${foo.bar}".
        name = re.sub(r"\${.*?}", "", name)
        # Remove any leading and trailing underscores left behind by the removal.
        name = re.sub(r"^_+|_+$", "", name)
        return name
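    # Illustrative example for normalized_key() (hypothetical job name): a DBX
    # workflow named "${var.environment}_nightly_etl" reduces to the bundle
    # resource key "nightly_etl".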

    def register_configuration(self, environment: str, config: Dict[str, Any]) -> None:
        self.configs[environment] = config

    def all_equal(self) -> bool:
        """Return True if every environment has an identical configuration."""
        keys = list(self.configs.keys())
        if len(keys) == 1:
            return True

        for i in range(1, len(keys)):
            if not recursive_compare(self.configs[keys[i - 1]], self.configs[keys[i]]):
                return False

        return True

    def compute_base(self) -> Dict[str, Any]:
        """Compute the configuration that is common to all environments."""
        keys = list(self.configs.keys())
        out = self.configs[keys[0]]
        for key in keys[1:]:
            out = recursive_intersection(out, self.configs[key])
        return out

    def compute_resource_definition(self) -> Dict[str, Any]:
        ordered_keys = [
            "name",
            "tags",
            "schedule",
            "email_notifications",
            "git_source",
            "permissions",
            "tasks",
            "job_clusters",
        ]

        obj = self.compute_base()

        # DBX uses "access_control_list" per job, whereas DABs use "permissions"
        # for the whole job. Translate the former into the latter.
        obj["permissions"] = [dict(e) for e in obj["access_control_list"]]
        for permission in obj["permissions"]:
            permission["level"] = permission["permission_level"]
            permission.pop("permission_level")

        return {
            "resources": {
                "jobs": {self.normalized_key(): {k: obj[k] for k in ordered_keys}}
            }
        }

    def compute_override_for_environment(self, environment: str) -> Dict[str, Any]:
        base = self.compute_base()
        if environment not in self.configs:
            return {}

        config = self.configs[environment]
        override = recursive_subtract(config, base)

        # If the configuration is identical to the base, there is nothing to override.
        if not override:
            return {}

        return {
            "targets": {
                environment: {
                    "resources": {
                        "jobs": {
                            self.normalized_key(): override,
                        }
                    }
                }
            }
        }
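# The methods above produce plain dicts that serialize to YAML roughly like the
# following sketch (illustrative job key and fields only):
#
#   resources:
#     jobs:
#       my_job:
#         name: my_job
#         tasks: [...]
#
#   targets:
#     dev:
#       resources:
#         jobs:
#           my_job:
#             <per-environment overrides>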


class LookupRewriter:
    @dataclasses.dataclass
    class RewriteType:
        variable_name_suffix: str
        object_type: str

    # Maps DBX lookup prefixes to the bundle variable they should be rewritten to.
    # Entries mapped to None are recognized but not supported; encountering one
    # raises an error in rewrite().
    _prefixes = {
        "cluster://": RewriteType(
            variable_name_suffix="cluster_id",
            object_type="cluster",
        ),
        "cluster-policy://": RewriteType(
            variable_name_suffix="cluster_policy_id",
            object_type="cluster_policy",
        ),
        "instance-profile://": None,
        "instance-pool://": RewriteType(
            variable_name_suffix="instance_pool_id",
            object_type="instance_pool",
        ),
        "pipeline://": None,
        "service-principal://": RewriteType(
            variable_name_suffix="service_principal_id",
            object_type="service_principal",
        ),
        "warehouse://": RewriteType(
            variable_name_suffix="warehouse_id",
            object_type="warehouse",
        ),
        "query://": None,
        "dashboard://": None,
        "alert://": None,
    }

    def __init__(self, job: Job) -> None:
        """
        One instance per job.

        All references are tracked by the environment they appear in so they can
        be differentiated if needed.
        """
        self.job = job
        self.variables = {}

    def add(self, env: str):
        def cb(path, obj):
            if isinstance(obj, str):
                for prefix in self._prefixes.keys():
                    if obj.startswith(prefix):
                        payload = obj.replace(prefix, "")
                        # Each prefix maps to a single variable per job, so the
                        # same prefix may only be referenced once per environment.
                        if any(p == prefix for p, _ in self.variables[env].values()):
                            raise ValueError(
                                f"Duplicate variable reference for {prefix} in {env}"
                            )
                        self.variables[env][str(path)] = [prefix, payload]
                        break
            return obj

        self.variables[env] = dict()
        walk(self.job.configs[env], cb)

    def confirm_envs_are_identical(self) -> Dict[str, Any]:
        # Run a deep comparison of the recorded references for every environment.
        keys = list(self.variables.keys())
        first = self.variables[keys[0]]
        for key in keys[1:]:
            diff = recursive_subtract(self.variables[key], first)
            if diff:
                raise ValueError("Variable references differ between environments")
        return first

    def rewrite(self) -> List[Dict[str, Any]]:
        """
        Returns a list of lookup variable definitions, each with "name",
        "lookup_type" and "lookup_value" keys. The databricks.yml template
        renders these into variables of the form:

        {
            "etl_cluster_policy_id": {
                "description": "<unknown>",
                "lookup": {
                    "cluster_policy": "some_policy"
                }
            }
        }
        """

        rewrites = dict()
        variables = []

        # Compile a list of variables and how to rewrite the existing references.
        for path, (prefix, payload) in self.confirm_envs_are_identical().items():
            rewrite = self._prefixes[prefix]
            if rewrite is None:
                raise ValueError(f"Unhandled prefix: {prefix}")

            variable_name = (
                f"{self.job.normalized_key()}_{rewrite.variable_name_suffix}"
            )

            # Add rewrite for the path.
            rewrites[path] = f"${{var.{variable_name}}}"

            # Add variable for the lookup.
            variables.append(
                {
                    "name": variable_name,
                    "lookup_type": rewrite.object_type,
                    "lookup_value": payload,
                }
            )

        # Now rewrite the job configurations.
        def cb(path, obj):
            rewrite = rewrites.get(str(path), None)
            if rewrite is not None:
                return rewrite
            return obj

        for env in self.job.configs.keys():
            self.job.configs[env] = walk(copy.deepcopy(self.job.configs[env]), cb)

        return variables
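# Illustrative rewrite performed by LookupRewriter.rewrite() (hypothetical values):
# a reference such as "cluster-policy://team-policy" in job "etl" becomes
# "${var.etl_cluster_policy_id}", and a lookup variable
# {"name": "etl_cluster_policy_id", "lookup_type": "cluster_policy",
#  "lookup_value": "team-policy"} is emitted for databricks.yml.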


def dedup_variables(variables):
    # Use a dict to deduplicate while preserving insertion order.
    deduped = dict()
    for v in variables:
        if v not in deduped:
            deduped[v] = None
    return deduped.keys()


def save_databricks_yml(
    base_path: Path, env_variables, var_variables, var_lookup_variables
):
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(Path(__file__).parent.joinpath("templates"))
    )
    template = env.get_template("databricks.yml.j2")

    base_name = base_path.name
    dst = base_path.joinpath("databricks.yml")
    print("Writing: ", dst)
    with open(dst, "w") as f:
        f.write(
            template.render(
                bundle_name=base_name,
                env_variables=env_variables,
                var_variables=var_variables,
                var_lookup_variables=var_lookup_variables,
            )
        )


def compute_variables_for_environment(
    environment: str, variables: Dict[str, Any]
) -> Dict[str, Any]:
    return {
        "targets": {
            environment: {
                "variables": variables,
            }
        }
    }
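# Sketch of what main() writes for a DBX project at <dir> (illustrative layout,
# derived from the joinpath() calls below):
#
#   <dir>/databricks.yml              bundle definition rendered from the template
#   <dir>/resources/<job>.yml         one resource definition per job
#   <dir>/conf/<env>/overrides.yml    per-target resource overrides (if any)
#   <dir>/conf/<env>/variables.yml    per-target variable values
#
# Example invocation (hypothetical module and project paths):
#
#   python -m dbx2dab.generate /path/to/dbx-project --verbose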


def main():
    parser = argparse.ArgumentParser(description="Generate Databricks configurations")
    parser.add_argument("dir", help="Path to the DBX project")
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Print verbose output"
    )
    args = parser.parse_args()
    verbose: bool = args.verbose

    base_path = Path(args.dir)
    loader = Loader(base_path)
    envs = loader.detect_environments()
    print("Detected environments:", envs)

    env_variables = []
    var_variables = []
    var_lookup_variables = []

    # Group workflow definitions by job name, one configuration per environment.
    jobs: Dict[str, Job] = dict()
    for env in envs:
        config, env_proxy, var_proxy = loader.load_for_environment(env)
        env_variables.extend(env_proxy.variables())
        var_variables.extend(var_proxy.variables())
        for workflow in config["workflows"]:
            name = workflow["name"]
            if name not in jobs:
                jobs[name] = Job(name)

            jobs[name].register_configuration(env, workflow)

    # Locate variable lookups and rewrite them into "${var.*}" references.
    for job in jobs.values():
        lr = LookupRewriter(job)
        for env in job.configs:
            lr.add(env)
        var_lookup_variables.extend(lr.rewrite())

    for job in jobs.values():
        base_job = job.compute_base()

        if verbose:
            print("Job:", job.name)

        # Write the job's resource definition to the "./resources" directory.
        resource_path = base_path.joinpath("resources")
        resource_path.mkdir(exist_ok=True)

        dst = resource_path.joinpath(f"{job.normalized_key()}.yml")
        print("Writing: ", dst)
        with open(dst, "w") as f:
            yaml.dump(
                job.compute_resource_definition(),
                f,
                Dumper=VerboseSafeDumper,
                sort_keys=False,
            )

        for environment, config in job.configs.items():
            diff = recursive_subtract(config, base_job)

            if verbose:
                yaml.dump(
                    diff,
                    sys.stdout,
                    indent=2,
                    Dumper=VerboseSafeDumper,
                    sort_keys=False,
                )

    # Write variable definitions
    env_variables = dedup_variables(env_variables)
    var_variables = dedup_variables(var_variables)
    save_databricks_yml(base_path, env_variables, var_variables, var_lookup_variables)

    # Write resource overrides
    for env in envs:
        out = {}
        for job in jobs.values():
            out = recursive_merge(out, job.compute_override_for_environment(env))

        if out:
            dst = base_path.joinpath(f"conf/{env}/overrides.yml")
            print("Writing: ", dst)
            with open(dst, "w") as f:
                yaml.dump(
                    out,
                    f,
                    Dumper=VerboseSafeDumper,
                    sort_keys=False,
                )

    # Write variable values
    for env in envs:
        variables = loader.load_variables_for_environment(env)

        # Include only variables listed in "var_variables"
        variables = {k: v for k, v in variables.items() if k in var_variables}

        out = compute_variables_for_environment(env, variables)
        if out:
            dst = base_path.joinpath(f"conf/{env}/variables.yml")
            print("Writing: ", dst)
            with open(dst, "w") as f:
                yaml.dump(
                    out,
                    f,
                    Dumper=VerboseSafeDumper,
                    sort_keys=False,
                )


if __name__ == "__main__":
    main()