databricks-cli/internal/testcli/runner.go

307 lines
7.8 KiB
Go

package testcli
import (
"bufio"
"bytes"
"context"
"encoding/json"
"io"
"reflect"
"strings"
"sync"
"time"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/require"
"github.com/databricks/cli/cmd"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
)
// Helper for running the root command in the background.
// It ensures that the background goroutine terminates upon
// test completion through cancelling the command context.
type Runner struct {
testutil.TestingT
args []string
stdout bytes.Buffer
stderr bytes.Buffer
stdinR *io.PipeReader
stdinW *io.PipeWriter
ctx context.Context
// Line-by-line output.
// Background goroutines populate these channels by reading from stdout/stderr pipes.
StdoutLines <-chan string
StderrLines <-chan string
errch <-chan error
}
func consumeLines(ctx context.Context, wg *sync.WaitGroup, r io.Reader) <-chan string {
ch := make(chan string, 30000)
wg.Add(1)
go func() {
defer close(ch)
defer wg.Done()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
// We expect to be able to always send these lines into the channel.
// If we can't, it means the channel is full and likely there is a problem
// in either the test or the code under test.
select {
case <-ctx.Done():
return
case ch <- scanner.Text():
continue
default:
panic("line buffer is full")
}
}
}()
return ch
}
func (r *Runner) registerFlagCleanup(c *cobra.Command) {
// Find target command that will be run. Example: if the command run is `databricks fs cp`,
// target command corresponds to `cp`
targetCmd, _, err := c.Find(r.args)
if err != nil && strings.HasPrefix(err.Error(), "unknown command") {
// even if command is unknown, we can proceed
require.NotNil(r, targetCmd)
} else {
require.NoError(r, err)
}
// Force initialization of default flags.
// These are initialized by cobra at execution time and would otherwise
// not be cleaned up by the cleanup function below.
targetCmd.InitDefaultHelpFlag()
targetCmd.InitDefaultVersionFlag()
// Restore flag values to their original value on test completion.
targetCmd.Flags().VisitAll(func(f *pflag.Flag) {
v := reflect.ValueOf(f.Value)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
// Store copy of the current flag value.
reset := reflect.New(v.Type()).Elem()
reset.Set(v)
r.Cleanup(func() {
v.Set(reset)
})
})
}
// Like [Runner.Eventually], but more specific
func (r *Runner) WaitForTextPrinted(text string, timeout time.Duration) {
r.Eventually(func() bool {
currentStdout := r.stdout.String()
return strings.Contains(currentStdout, text)
}, timeout, 50*time.Millisecond)
}
func (r *Runner) WaitForOutput(text string, timeout time.Duration) {
require.Eventually(r, func() bool {
currentStdout := r.stdout.String()
currentErrout := r.stderr.String()
return strings.Contains(currentStdout, text) || strings.Contains(currentErrout, text)
}, timeout, 50*time.Millisecond)
}
func (r *Runner) WithStdin() {
reader, writer := io.Pipe()
r.stdinR = reader
r.stdinW = writer
}
func (r *Runner) CloseStdin() {
if r.stdinW == nil {
panic("no standard input configured")
}
r.stdinW.Close()
}
func (r *Runner) SendText(text string) {
if r.stdinW == nil {
panic("no standard input configured")
}
_, err := r.stdinW.Write([]byte(text + "\n"))
if err != nil {
panic("Failed to to write to t.stdinW")
}
}
func (r *Runner) RunBackground() {
var stdoutR, stderrR io.Reader
var stdoutW, stderrW io.WriteCloser
stdoutR, stdoutW = io.Pipe()
stderrR, stderrW = io.Pipe()
ctx := cmdio.NewContext(r.ctx, &cmdio.Logger{
Mode: flags.ModeAppend,
Reader: bufio.Reader{},
Writer: stderrW,
})
cli := cmd.New(ctx)
cli.SetOut(stdoutW)
cli.SetErr(stderrW)
cli.SetArgs(r.args)
if r.stdinW != nil {
cli.SetIn(r.stdinR)
}
// Register cleanup function to restore flags to their original values
// once test has been executed. This is needed because flag values reside
// in a global singleton data-structure, and thus subsequent tests might
// otherwise interfere with each other
r.registerFlagCleanup(cli)
errch := make(chan error)
ctx, cancel := context.WithCancel(ctx)
// Tee stdout/stderr to buffers.
stdoutR = io.TeeReader(stdoutR, &r.stdout)
stderrR = io.TeeReader(stderrR, &r.stderr)
// Consume stdout/stderr line-by-line.
var wg sync.WaitGroup
r.StdoutLines = consumeLines(ctx, &wg, stdoutR)
r.StderrLines = consumeLines(ctx, &wg, stderrR)
// Run command in background.
go func() {
err := root.Execute(ctx, cli)
if err != nil {
r.Logf("Error running command: %s", err)
}
// Close pipes to signal EOF.
stdoutW.Close()
stderrW.Close()
// Wait for the [consumeLines] routines to finish now that
// the pipes they're reading from have closed.
wg.Wait()
if r.stdout.Len() > 0 {
// Make a copy of the buffer such that it remains "unread".
scanner := bufio.NewScanner(bytes.NewBuffer(r.stdout.Bytes()))
for scanner.Scan() {
r.Logf("[databricks stdout]: %s", scanner.Text())
}
}
if r.stderr.Len() > 0 {
// Make a copy of the buffer such that it remains "unread".
scanner := bufio.NewScanner(bytes.NewBuffer(r.stderr.Bytes()))
for scanner.Scan() {
r.Logf("[databricks stderr]: %s", scanner.Text())
}
}
// Reset context on command for the next test.
// These commands are globals so we have to clean up to the best of our ability after each run.
// See https://github.com/spf13/cobra/blob/a6f198b635c4b18fff81930c40d464904e55b161/command.go#L1062-L1066
//nolint:staticcheck // cobra sets the context and doesn't clear it
cli.SetContext(nil)
// Make caller aware of error.
errch <- err
close(errch)
}()
// Ensure command terminates upon test completion (success or failure).
r.Cleanup(func() {
// Signal termination of command.
cancel()
// Wait for goroutine to finish.
<-errch
})
r.errch = errch
}
func (r *Runner) Run() (bytes.Buffer, bytes.Buffer, error) {
r.RunBackground()
err := <-r.errch
return r.stdout, r.stderr, err
}
// Like [require.Eventually] but errors if the underlying command has failed.
func (r *Runner) Eventually(condition func() bool, waitFor, tick time.Duration, msgAndArgs ...any) {
ch := make(chan bool, 1)
timer := time.NewTimer(waitFor)
defer timer.Stop()
ticker := time.NewTicker(tick)
defer ticker.Stop()
// Kick off condition check immediately.
go func() { ch <- condition() }()
for tick := ticker.C; ; {
select {
case err := <-r.errch:
require.Fail(r, "Command failed", err)
return
case <-timer.C:
require.Fail(r, "Condition never satisfied", msgAndArgs...)
return
case <-tick:
tick = nil
go func() { ch <- condition() }()
case v := <-ch:
if v {
return
}
tick = ticker.C
}
}
}
func (r *Runner) RunAndExpectOutput(heredoc string) {
stdout, _, err := r.Run()
require.NoError(r, err)
require.Equal(r, cmdio.Heredoc(heredoc), strings.TrimSpace(stdout.String()))
}
func (r *Runner) RunAndParseJSON(v any) {
stdout, _, err := r.Run()
require.NoError(r, err)
err = json.Unmarshal(stdout.Bytes(), &v)
require.NoError(r, err)
}
func NewRunner(t testutil.TestingT, ctx context.Context, args ...string) *Runner {
return &Runner{
TestingT: t,
ctx: ctx,
args: args,
}
}
func RequireSuccessfulRun(t testutil.TestingT, ctx context.Context, args ...string) (bytes.Buffer, bytes.Buffer) {
t.Logf("run args: [%s]", strings.Join(args, ", "))
r := NewRunner(t, ctx, args...)
stdout, stderr, err := r.Run()
require.NoError(t, err)
return stdout, stderr
}
func RequireErrorRun(t testutil.TestingT, ctx context.Context, args ...string) (bytes.Buffer, bytes.Buffer, error) {
r := NewRunner(t, ctx, args...)
stdout, stderr, err := r.Run()
require.Error(t, err)
return stdout, stderr, err
}