// databricks-cli/bundle/run/job.go

package run

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"time"

	"github.com/databricks/cli/bundle"
	"github.com/databricks/cli/bundle/config/resources"
	"github.com/databricks/cli/bundle/run/output"
	"github.com/databricks/cli/bundle/run/progress"
	"github.com/databricks/cli/libs/cmdio"
	"github.com/databricks/cli/libs/log"
	"github.com/databricks/databricks-sdk-go/service/jobs"
	"github.com/fatih/color"
	flag "github.com/spf13/pflag"
)

// JobOptions defines options for running a job.
type JobOptions struct {
	dbtCommands       []string
	jarParams         []string
	notebookParams    map[string]string
	pipelineParams    map[string]string
	pythonNamedParams map[string]string
	pythonParams      []string
	sparkSubmitParams []string
	sqlParams         map[string]string
}
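
// Define registers the job parameter flags on the given flag set. Each flag
// corresponds to a field of the jobs.RunNow request payload.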
func (o *JobOptions) Define(fs *flag.FlagSet) {
	fs.StringSliceVar(&o.dbtCommands, "dbt-commands", nil, "A list of commands to execute for jobs with DBT tasks.")
	fs.StringSliceVar(&o.jarParams, "jar-params", nil, "A list of parameters for jobs with Spark JAR tasks.")
	fs.StringToStringVar(&o.notebookParams, "notebook-params", nil, "A map from keys to values for jobs with notebook tasks.")
	fs.StringToStringVar(&o.pipelineParams, "pipeline-params", nil, "A map from keys to values for jobs with pipeline tasks.")
	fs.StringToStringVar(&o.pythonNamedParams, "python-named-params", nil, "A map from keys to values for jobs with Python wheel tasks.")
	fs.StringSliceVar(&o.pythonParams, "python-params", nil, "A list of parameters for jobs with Python tasks.")
	fs.StringSliceVar(&o.sparkSubmitParams, "spark-submit-params", nil, "A list of parameters for jobs with Spark submit tasks.")
	fs.StringToStringVar(&o.sqlParams, "sql-params", nil, "A map from keys to values for jobs with SQL tasks.")
}
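
// Example invocations (the job key "my_job" is hypothetical; the flag shapes
// follow the definitions above):
//
//	databricks bundle run my_job --notebook-params foo=bar,baz=qux
//	databricks bundle run my_job --pipeline-params full_refresh=true

// validatePipelineParams checks that --pipeline-params carries only the
// supported full_refresh=<bool> key and converts it into the SDK's
// PipelineParams payload.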
func (o *JobOptions) validatePipelineParams() (*jobs.PipelineParams, error) {
	if len(o.pipelineParams) == 0 {
		return nil, nil
	}

	defaultErr := fmt.Errorf("job run argument --pipeline-params only supports `full_refresh=<bool>`")
	v, ok := o.pipelineParams["full_refresh"]
	if !ok {
		return nil, defaultErr
	}
	b, err := strconv.ParseBool(v)
	if err != nil {
		return nil, defaultErr
	}

	pipelineParams := &jobs.PipelineParams{
		FullRefresh: b,
	}
	return pipelineParams, nil
}
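
// toPayload assembles the RunNow request payload for the given job ID from
// the collected flag values.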
func (o *JobOptions) toPayload(jobID int64) (*jobs.RunNow, error) {
	pipelineParams, err := o.validatePipelineParams()
	if err != nil {
		return nil, err
	}

	payload := &jobs.RunNow{
		JobId:             jobID,
		DbtCommands:       o.dbtCommands,
		JarParams:         o.jarParams,
		NotebookParams:    o.notebookParams,
		PipelineParams:    pipelineParams,
		PythonNamedParams: o.pythonNamedParams,
		PythonParams:      o.pythonParams,
		SparkSubmitParams: o.sparkSubmitParams,
		SqlParams:         o.sqlParams,
	}
	return payload, nil
}

// Default timeout for waiting for a job run to complete.
var jobRunTimeout time.Duration = 24 * time.Hour
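
// jobRunner runs job resources defined in a bundle.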
type jobRunner struct {
	key

	bundle *bundle.Bundle
	job    *resources.Job
}
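
// Name returns the configured job name, or an empty string if the job
// settings are not available.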
func (r *jobRunner) Name() string {
	if r.job == nil || r.job.JobSettings == nil {
		return ""
	}
	return r.job.JobSettings.Name
}
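
// isFailed reports whether a task run ended in failure: an internal error,
// or a terminated run with a FAILED result state.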
func isFailed(task jobs.RunTask) bool {
	return task.State.LifeCycleState == jobs.RunLifeCycleStateInternalError ||
		(task.State.LifeCycleState == jobs.RunLifeCycleStateTerminated &&
			task.State.ResultState == jobs.RunResultStateFailed)
}
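
// isSuccess reports whether a task run terminated with a SUCCESS result state.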
func isSuccess(task jobs.RunTask) bool {
	return task.State.LifeCycleState == jobs.RunLifeCycleStateTerminated &&
		task.State.ResultState == jobs.RunResultStateSuccess
}
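
// logFailedTasks fetches the given run and, if it did not succeed, logs the
// outcome of every task: successes in green, failures in red together with
// their error trace, and any other lifecycle state in yellow.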
func (r *jobRunner) logFailedTasks(ctx context.Context, runId int64) {
	w := r.bundle.WorkspaceClient()
	red := color.New(color.FgRed).SprintFunc()
	green := color.New(color.FgGreen).SprintFunc()
	yellow := color.New(color.FgYellow).SprintFunc()
	run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{
		RunId: runId,
	})
	if err != nil {
		log.Errorf(ctx, "failed to fetch job run to log failed tasks: %s", err)
		return
	}
	if run.State.ResultState == jobs.RunResultStateSuccess {
		return
	}
	for _, task := range run.Tasks {
		if isSuccess(task) {
			log.Infof(ctx, "task %s completed successfully", green(task.TaskKey))
		} else if isFailed(task) {
			taskInfo, err := w.Jobs.GetRunOutput(ctx, jobs.GetRunOutputRequest{
				RunId: task.RunId,
			})
			if err != nil {
				log.Errorf(ctx, "task %s failed. Unable to fetch error trace: %s", red(task.TaskKey), err)
				continue
			}
			if progressLogger, ok := cmdio.FromContext(ctx); ok {
				progressLogger.Log(progress.NewTaskErrorEvent(task.TaskKey, taskInfo.Error, taskInfo.ErrorTrace))
			}
			log.Errorf(ctx, "task %s failed!\nError:\n%s\nTrace:\n%s",
				red(task.TaskKey), taskInfo.Error, taskInfo.ErrorTrace)
		} else {
			log.Infof(ctx, "task %s is in state %s",
				yellow(task.TaskKey), task.State.LifeCycleState)
		}
	}
}
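
// pullRunIdCallback returns a poll callback that records the run ID in
// *runId the first time it is reported.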
func pullRunIdCallback(runId *int64) func(info *jobs.Run) {
	return func(i *jobs.Run) {
		if *runId == 0 {
			*runId = i.RunId
		}
	}
}
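
// logDebugCallback returns a poll callback that logs the run page URL once,
// and the lifecycle state whenever it changes.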
func logDebugCallback(ctx context.Context, runId *int64) func(info *jobs.Run) {
	var prevState *jobs.RunState
	return func(i *jobs.Run) {
		state := i.State
		if state == nil {
			return
		}
		// Log the job run URL as soon as it is available.
		if prevState == nil {
			log.Infof(ctx, "Run available at %s", i.RunPageUrl)
		}
		if prevState == nil || prevState.LifeCycleState != state.LifeCycleState {
			log.Infof(ctx, "Run status: %s", i.State.LifeCycleState)
			prevState = state
		}
	}
}
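
// logProgressCallback returns a poll callback that emits the run page URL
// once, then a JobProgressEvent whenever the lifecycle or result state
// changes, both to the progress logger and to the default logger.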
func logProgressCallback(ctx context.Context, progressLogger *cmdio.Logger) func(info *jobs.Run) {
	var prevState *jobs.RunState
	return func(i *jobs.Run) {
		state := i.State
		if state == nil {
			return
		}
		if prevState == nil {
			progressLogger.Log(progress.NewJobRunUrlEvent(i.RunPageUrl))
		}
		// Only emit an event when the state has changed since the last poll.
		if prevState != nil && prevState.LifeCycleState == state.LifeCycleState &&
			prevState.ResultState == state.ResultState {
			return
		}
		prevState = state

		event := &progress.JobProgressEvent{
			Timestamp: time.Now(),
			JobId:     i.JobId,
			RunId:     i.RunId,
			RunName:   i.RunName,
			State:     *i.State,
		}
		// Log progress events to stderr.
		progressLogger.Log(event)
		// Log progress events using the default logger.
		log.Infof(ctx, "%s", event.String())
	}
}
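
// Run triggers the job and, unless opts.NoWait is set, polls the run until
// it reaches a terminal state or jobRunTimeout elapses. On success it
// returns the job run output; otherwise it logs the failed tasks and returns
// an error describing the terminal state.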
func (r *jobRunner) Run(ctx context.Context, opts *Options) (output.RunOutput, error) {
	jobID, err := strconv.ParseInt(r.job.ID, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("job ID is not an integer: %s", r.job.ID)
	}

	runId := new(int64)

	err = r.convertPythonParams(opts)
	if err != nil {
		return nil, err
	}

	// Construct the request payload from the command-line flags.
	req, err := opts.Job.toPayload(jobID)
	if err != nil {
		return nil, err
	}

	// Include resource key in logger.
	ctx = log.NewContext(ctx, log.GetLogger(ctx).With("resource", r.Key()))
	w := r.bundle.WorkspaceClient()

	// Callback to record the run ID as soon as polling reports it.
	pullRunId := pullRunIdCallback(runId)

	// Callback to log status updates to the default log destination.
	// Called on every poll request.
	logDebug := logDebugCallback(ctx, runId)

	// Callback to log progress events. Called on every poll request.
	progressLogger, ok := cmdio.FromContext(ctx)
	if !ok {
		return nil, fmt.Errorf("no progress logger found")
	}
	logProgress := logProgressCallback(ctx, progressLogger)

	waiter, err := w.Jobs.RunNow(ctx, *req)
	if err != nil {
		return nil, fmt.Errorf("cannot start job: %w", err)
	}

	if opts.NoWait {
		details, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{
			RunId: waiter.RunId,
		})
		if err != nil {
			return nil, err
		}
		progressLogger.Log(progress.NewJobRunUrlEvent(details.RunPageUrl))
		return nil, nil
	}

	run, err := waiter.OnProgress(func(info *jobs.Run) {
		pullRunId(info)
		logDebug(info)
		logProgress(info)
	}).GetWithTimeout(jobRunTimeout)
	if err != nil && *runId != 0 {
		r.logFailedTasks(ctx, *runId)
	}
	if err != nil {
		return nil, err
	}

	if run.State.LifeCycleState == jobs.RunLifeCycleStateSkipped {
		log.Infof(ctx, "Run was skipped!")
		return nil, fmt.Errorf("run skipped: %s", run.State.StateMessage)
	}

	switch run.State.ResultState {
	// The run was canceled at user request.
	case jobs.RunResultStateCanceled:
		log.Infof(ctx, "Run was canceled!")
		return nil, fmt.Errorf("run canceled: %s", run.State.StateMessage)

	// The task completed with an error.
	case jobs.RunResultStateFailed:
		log.Infof(ctx, "Run has failed!")
		return nil, fmt.Errorf("run failed: %s", run.State.StateMessage)

	// The task completed successfully.
	case jobs.RunResultStateSuccess:
		log.Infof(ctx, "Run has completed successfully!")
		return output.GetJobOutput(ctx, r.bundle.WorkspaceClient(), *runId)

	// The run was stopped after reaching the timeout.
	case jobs.RunResultStateTimedout:
		log.Infof(ctx, "Run has timed out!")
		return nil, fmt.Errorf("run timed out: %s", run.State.StateMessage)
	}

	return nil, nil
}
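
// convertPythonParams supports the experimental Python wheel wrapper: when
// the wrapper is enabled, wheel tasks run through a trampoline notebook, so
// any --python-params are serialized to JSON and passed along as the
// reserved __python_params notebook parameter for the wrapper to unpack.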
func (r *jobRunner) convertPythonParams(opts *Options) error {
	if r.bundle.Config.Experimental == nil || !r.bundle.Config.Experimental.PythonWheelWrapper {
		return nil
	}

	needConvert := false
	for _, task := range r.job.Tasks {
		if task.PythonWheelTask != nil {
			needConvert = true
			break
		}
	}
	if !needConvert {
		return nil
	}
	if len(opts.Job.pythonParams) == 0 {
		return nil
	}

	if opts.Job.notebookParams == nil {
		opts.Job.notebookParams = make(map[string]string)
	}
	if _, ok := opts.Job.notebookParams["__python_params"]; ok {
		return fmt.Errorf("can't use __python_params as notebook param, the name is reserved for internal use")
	}
	p, err := json.Marshal(opts.Job.pythonParams)
	if err != nil {
		return err
	}
	opts.Job.notebookParams["__python_params"] = string(p)

	return nil
}