Define and use `testutil.TestingT` interface (#2003)

## Changes

Using an interface instead of a concrete type means we can pass
`*testing.T` directly or any wrapper type that implements a superset of
this interface. It prepares for more broad use of `acc.WorkspaceT`,
which enhances the testing object with helper functions for using a
Databricks workspace.

This eliminates the need to dereference a `*testing.T` field on a
wrapper type.

## Tests

n/a
This commit is contained in:
Pieter Noordhuis 2024-12-12 15:42:15 +01:00 committed by GitHub
parent cabdabf31e
commit dd3b7ec450
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 134 additions and 83 deletions

View File

@ -6,7 +6,8 @@ import (
"path"
"path/filepath"
"strings"
"testing"
"github.com/databricks/cli/internal/testutil"
)
// Detects if test is run from "debug test" feature in VS Code.
@ -16,7 +17,7 @@ func isInDebug() bool {
}
// Loads debug environment from ~/.databricks/debug-env.json.
func loadDebugEnvIfRunFromIDE(t *testing.T, key string) {
func loadDebugEnvIfRunFromIDE(t testutil.TestingT, key string) {
if !isInDebug() {
return
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"testing"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/databricks-sdk-go"
@ -15,7 +14,7 @@ import (
)
type WorkspaceT struct {
*testing.T
testutil.TestingT
W *databricks.WorkspaceClient
@ -24,7 +23,7 @@ type WorkspaceT struct {
exec *compute.CommandExecutorV2
}
func WorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
func WorkspaceTest(t testutil.TestingT) (context.Context, *WorkspaceT) {
loadDebugEnvIfRunFromIDE(t, "workspace")
t.Log(testutil.GetEnvOrSkipTest(t, "CLOUD_ENV"))
@ -33,7 +32,7 @@ func WorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
require.NoError(t, err)
wt := &WorkspaceT{
T: t,
TestingT: t,
W: w,
@ -44,7 +43,7 @@ func WorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
}
// Run the workspace test only on UC workspaces.
func UcWorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
func UcWorkspaceTest(t testutil.TestingT) (context.Context, *WorkspaceT) {
loadDebugEnvIfRunFromIDE(t, "workspace")
t.Log(testutil.GetEnvOrSkipTest(t, "CLOUD_ENV"))
@ -60,7 +59,7 @@ func UcWorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
require.NoError(t, err)
wt := &WorkspaceT{
T: t,
TestingT: t,
W: w,
@ -71,7 +70,7 @@ func UcWorkspaceTest(t *testing.T) (context.Context, *WorkspaceT) {
}
func (t *WorkspaceT) TestClusterID() string {
clusterID := testutil.GetEnvOrSkipTest(t.T, "TEST_BRICKS_CLUSTER_ID")
clusterID := testutil.GetEnvOrSkipTest(t, "TEST_BRICKS_CLUSTER_ID")
err := t.W.Clusters.EnsureClusterIsRunning(t.ctx, clusterID)
require.NoError(t, err)
return clusterID

View File

@ -16,7 +16,7 @@ import (
func TestAccDeployBundleWithCluster(t *testing.T) {
ctx, wt := acc.WorkspaceTest(t)
if testutil.IsAWSCloud(wt.T) {
if testutil.IsAWSCloud(wt) {
t.Skip("Skipping test for AWS cloud because it is not permitted to create clusters")
}

View File

@ -29,10 +29,10 @@ func TestAccDeployBasicToSharedWorkspacePath(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() {
err = destroyBundle(wt.T, ctx, bundleRoot)
require.NoError(wt.T, err)
err = destroyBundle(wt, ctx, bundleRoot)
require.NoError(wt, err)
})
err = deployBundle(wt.T, ctx, bundleRoot)
require.NoError(wt.T, err)
err = deployBundle(wt, ctx, bundleRoot)
require.NoError(wt, err)
}

View File

@ -9,11 +9,11 @@ import (
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/internal"
"github.com/databricks/cli/internal/testutil"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/env"
"github.com/databricks/cli/libs/filer"
@ -26,12 +26,12 @@ import (
const defaultSparkVersion = "13.3.x-snapshot-scala2.12"
func initTestTemplate(t *testing.T, ctx context.Context, templateName string, config map[string]any) (string, error) {
func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any) (string, error) {
bundleRoot := t.TempDir()
return initTestTemplateWithBundleRoot(t, ctx, templateName, config, bundleRoot)
}
func initTestTemplateWithBundleRoot(t *testing.T, ctx context.Context, templateName string, config map[string]any, bundleRoot string) (string, error) {
func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) (string, error) {
templateRoot := filepath.Join("bundles", templateName)
configFilePath, err := writeConfigFile(t, config)
@ -49,7 +49,7 @@ func initTestTemplateWithBundleRoot(t *testing.T, ctx context.Context, templateN
return bundleRoot, err
}
func writeConfigFile(t *testing.T, config map[string]any) (string, error) {
func writeConfigFile(t testutil.TestingT, config map[string]any) (string, error) {
bytes, err := json.Marshal(config)
if err != nil {
return "", err
@ -63,34 +63,34 @@ func writeConfigFile(t *testing.T, config map[string]any) (string, error) {
return filepath, err
}
func validateBundle(t *testing.T, ctx context.Context, path string) ([]byte, error) {
func validateBundle(t testutil.TestingT, ctx context.Context, path string) ([]byte, error) {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
c := internal.NewCobraTestRunnerWithContext(t, ctx, "bundle", "validate", "--output", "json")
stdout, _, err := c.Run()
return stdout.Bytes(), err
}
func mustValidateBundle(t *testing.T, ctx context.Context, path string) []byte {
func mustValidateBundle(t testutil.TestingT, ctx context.Context, path string) []byte {
data, err := validateBundle(t, ctx, path)
require.NoError(t, err)
return data
}
func unmarshalConfig(t *testing.T, data []byte) *bundle.Bundle {
func unmarshalConfig(t testutil.TestingT, data []byte) *bundle.Bundle {
bundle := &bundle.Bundle{}
err := json.Unmarshal(data, &bundle.Config)
require.NoError(t, err)
return bundle
}
func deployBundle(t *testing.T, ctx context.Context, path string) error {
func deployBundle(t testutil.TestingT, ctx context.Context, path string) error {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
c := internal.NewCobraTestRunnerWithContext(t, ctx, "bundle", "deploy", "--force-lock", "--auto-approve")
_, _, err := c.Run()
return err
}
func deployBundleWithArgs(t *testing.T, ctx context.Context, path string, args ...string) (string, string, error) {
func deployBundleWithArgs(t testutil.TestingT, ctx context.Context, path string, args ...string) (string, string, error) {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
args = append([]string{"bundle", "deploy"}, args...)
c := internal.NewCobraTestRunnerWithContext(t, ctx, args...)
@ -98,7 +98,7 @@ func deployBundleWithArgs(t *testing.T, ctx context.Context, path string, args .
return stdout.String(), stderr.String(), err
}
func deployBundleWithFlags(t *testing.T, ctx context.Context, path string, flags []string) error {
func deployBundleWithFlags(t testutil.TestingT, ctx context.Context, path string, flags []string) error {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
args := []string{"bundle", "deploy", "--force-lock"}
args = append(args, flags...)
@ -107,7 +107,7 @@ func deployBundleWithFlags(t *testing.T, ctx context.Context, path string, flags
return err
}
func runResource(t *testing.T, ctx context.Context, path, key string) (string, error) {
func runResource(t testutil.TestingT, ctx context.Context, path, key string) (string, error) {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
ctx = cmdio.NewContext(ctx, cmdio.Default())
@ -116,7 +116,7 @@ func runResource(t *testing.T, ctx context.Context, path, key string) (string, e
return stdout.String(), err
}
func runResourceWithParams(t *testing.T, ctx context.Context, path, key string, params ...string) (string, error) {
func runResourceWithParams(t testutil.TestingT, ctx context.Context, path, key string, params ...string) (string, error) {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
ctx = cmdio.NewContext(ctx, cmdio.Default())
@ -128,14 +128,14 @@ func runResourceWithParams(t *testing.T, ctx context.Context, path, key string,
return stdout.String(), err
}
func destroyBundle(t *testing.T, ctx context.Context, path string) error {
func destroyBundle(t testutil.TestingT, ctx context.Context, path string) error {
ctx = env.Set(ctx, "BUNDLE_ROOT", path)
c := internal.NewCobraTestRunnerWithContext(t, ctx, "bundle", "destroy", "--auto-approve")
_, _, err := c.Run()
return err
}
func getBundleRemoteRootPath(w *databricks.WorkspaceClient, t *testing.T, uniqueId string) string {
func getBundleRemoteRootPath(w *databricks.WorkspaceClient, t testutil.TestingT, uniqueId string) string {
// Compute root path for the bundle deployment
me, err := w.CurrentUser.Me(context.Background())
require.NoError(t, err)
@ -143,7 +143,7 @@ func getBundleRemoteRootPath(w *databricks.WorkspaceClient, t *testing.T, unique
return root
}
func blackBoxRun(t *testing.T, root string, args ...string) (stdout, stderr string) {
func blackBoxRun(t testutil.TestingT, root string, args ...string) (stdout, stderr string) {
gitRoot, err := folders.FindDirWithLeaf(".", ".git")
require.NoError(t, err)

View File

@ -57,7 +57,7 @@ func TestAccPythonWheelTaskDeployAndRunWithWrapper(t *testing.T) {
func TestAccPythonWheelTaskDeployAndRunOnInteractiveCluster(t *testing.T) {
_, wt := acc.WorkspaceTest(t)
if testutil.IsAWSCloud(wt.T) {
if testutil.IsAWSCloud(wt) {
t.Skip("Skipping test for AWS cloud because it is not permitted to create clusters")
}

View File

@ -122,7 +122,7 @@ func TestAccFilerRecursiveDelete(t *testing.T) {
for _, testCase := range []struct {
name string
f func(t *testing.T) (filer.Filer, string)
f func(t testutil.TestingT) (filer.Filer, string)
}{
{"local", setupLocalFiler},
{"workspace files", setupWsfsFiler},
@ -233,7 +233,7 @@ func TestAccFilerReadWrite(t *testing.T) {
for _, testCase := range []struct {
name string
f func(t *testing.T) (filer.Filer, string)
f func(t testutil.TestingT) (filer.Filer, string)
}{
{"local", setupLocalFiler},
{"workspace files", setupWsfsFiler},
@ -342,7 +342,7 @@ func TestAccFilerReadDir(t *testing.T) {
for _, testCase := range []struct {
name string
f func(t *testing.T) (filer.Filer, string)
f func(t testutil.TestingT) (filer.Filer, string)
}{
{"local", setupLocalFiler},
{"workspace files", setupWsfsFiler},

View File

@ -62,8 +62,8 @@ func assertTargetDir(t *testing.T, ctx context.Context, f filer.Filer) {
type cpTest struct {
name string
setupSource func(*testing.T) (filer.Filer, string)
setupTarget func(*testing.T) (filer.Filer, string)
setupSource func(testutil.TestingT) (filer.Filer, string)
setupTarget func(testutil.TestingT) (filer.Filer, string)
}
func copyTests() []cpTest {

View File

@ -18,7 +18,7 @@ import (
type fsTest struct {
name string
setupFiler func(t *testing.T) (filer.Filer, string)
setupFiler func(t testutil.TestingT) (filer.Filer, string)
}
var fsTests = []fsTest{

View File

@ -15,7 +15,6 @@ import (
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/databricks/cli/cmd/root"
@ -45,7 +44,7 @@ import (
// It ensures that the background goroutine terminates upon
// test completion through cancelling the command context.
type cobraTestRunner struct {
*testing.T
testutil.TestingT
args []string
stdout bytes.Buffer
@ -128,7 +127,7 @@ func (t *cobraTestRunner) WaitForTextPrinted(text string, timeout time.Duration)
}
func (t *cobraTestRunner) WaitForOutput(text string, timeout time.Duration) {
require.Eventually(t.T, func() bool {
require.Eventually(t, func() bool {
currentStdout := t.stdout.String()
currentErrout := t.stderr.String()
return strings.Contains(currentStdout, text) || strings.Contains(currentErrout, text)
@ -300,23 +299,25 @@ func (t *cobraTestRunner) RunAndParseJSON(v any) {
require.NoError(t, err)
}
func NewCobraTestRunner(t *testing.T, args ...string) *cobraTestRunner {
func NewCobraTestRunner(t testutil.TestingT, args ...string) *cobraTestRunner {
return &cobraTestRunner{
T: t,
TestingT: t,
ctx: context.Background(),
args: args,
}
}
func NewCobraTestRunnerWithContext(t *testing.T, ctx context.Context, args ...string) *cobraTestRunner {
func NewCobraTestRunnerWithContext(t testutil.TestingT, ctx context.Context, args ...string) *cobraTestRunner {
return &cobraTestRunner{
T: t,
TestingT: t,
ctx: ctx,
args: args,
}
}
func RequireSuccessfulRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer) {
func RequireSuccessfulRun(t testutil.TestingT, args ...string) (bytes.Buffer, bytes.Buffer) {
t.Logf("run args: [%s]", strings.Join(args, ", "))
c := NewCobraTestRunner(t, args...)
stdout, stderr, err := c.Run()
@ -324,7 +325,7 @@ func RequireSuccessfulRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buf
return stdout, stderr
}
func RequireErrorRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer, error) {
func RequireErrorRun(t testutil.TestingT, args ...string) (bytes.Buffer, bytes.Buffer, error) {
c := NewCobraTestRunner(t, args...)
stdout, stderr, err := c.Run()
require.Error(t, err)
@ -398,7 +399,7 @@ func GenerateWheelTasks(wheelPath string, versions []string, nodeTypeId string)
return tasks
}
func TemporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string {
func TemporaryWorkspaceDir(t testutil.TestingT, w *databricks.WorkspaceClient) string {
ctx := context.Background()
me, err := w.CurrentUser.Me(ctx)
require.NoError(t, err)
@ -425,7 +426,7 @@ func TemporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string {
return basePath
}
func TemporaryDbfsDir(t *testing.T, w *databricks.WorkspaceClient) string {
func TemporaryDbfsDir(t testutil.TestingT, w *databricks.WorkspaceClient) string {
ctx := context.Background()
path := fmt.Sprintf("/tmp/%s", testutil.RandomName("integration-test-dbfs-"))
@ -449,7 +450,7 @@ func TemporaryDbfsDir(t *testing.T, w *databricks.WorkspaceClient) string {
}
// Create a new UC volume in a catalog called "main" in the workspace.
func TemporaryUcVolume(t *testing.T, w *databricks.WorkspaceClient) string {
func TemporaryUcVolume(t testutil.TestingT, w *databricks.WorkspaceClient) string {
ctx := context.Background()
// Create a schema
@ -483,7 +484,7 @@ func TemporaryUcVolume(t *testing.T, w *databricks.WorkspaceClient) string {
return path.Join("/Volumes", "main", schema.Name, volume.Name)
}
func TemporaryRepo(t *testing.T, w *databricks.WorkspaceClient) string {
func TemporaryRepo(t testutil.TestingT, w *databricks.WorkspaceClient) string {
ctx := context.Background()
me, err := w.CurrentUser.Me(ctx)
require.NoError(t, err)
@ -522,7 +523,7 @@ func GetNodeTypeId(env string) string {
return "Standard_DS4_v2"
}
func setupLocalFiler(t *testing.T) (filer.Filer, string) {
func setupLocalFiler(t testutil.TestingT) (filer.Filer, string) {
t.Log(testutil.GetEnvOrSkipTest(t, "CLOUD_ENV"))
tmp := t.TempDir()
@ -532,7 +533,7 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) {
return f, path.Join(filepath.ToSlash(tmp))
}
func setupWsfsFiler(t *testing.T) (filer.Filer, string) {
func setupWsfsFiler(t testutil.TestingT) (filer.Filer, string) {
ctx, wt := acc.WorkspaceTest(t)
tmpdir := TemporaryWorkspaceDir(t, wt.W)
@ -549,7 +550,7 @@ func setupWsfsFiler(t *testing.T) (filer.Filer, string) {
return f, tmpdir
}
func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
func setupWsfsExtensionsFiler(t testutil.TestingT) (filer.Filer, string) {
_, wt := acc.WorkspaceTest(t)
tmpdir := TemporaryWorkspaceDir(t, wt.W)
@ -559,7 +560,7 @@ func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
return f, tmpdir
}
func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
func setupDbfsFiler(t testutil.TestingT) (filer.Filer, string) {
_, wt := acc.WorkspaceTest(t)
tmpDir := TemporaryDbfsDir(t, wt.W)
@ -569,7 +570,7 @@ func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
return f, path.Join("dbfs:/", tmpDir)
}
func setupUcVolumesFiler(t *testing.T) (filer.Filer, string) {
func setupUcVolumesFiler(t testutil.TestingT) (filer.Filer, string) {
t.Log(testutil.GetEnvOrSkipTest(t, "CLOUD_ENV"))
if os.Getenv("TEST_METASTORE_ID") == "" {

View File

@ -68,7 +68,7 @@ func TestAccSecretsPutSecretStringValue(tt *testing.T) {
key := "test-key"
value := "test-value\nwith-newlines\n"
stdout, stderr := RequireSuccessfulRun(t.T, "secrets", "put-secret", scope, key, "--string-value", value)
stdout, stderr := RequireSuccessfulRun(t, "secrets", "put-secret", scope, key, "--string-value", value)
assert.Empty(t, stdout)
assert.Empty(t, stderr)
@ -82,7 +82,7 @@ func TestAccSecretsPutSecretBytesValue(tt *testing.T) {
key := "test-key"
value := []byte{0x00, 0x01, 0x02, 0x03}
stdout, stderr := RequireSuccessfulRun(t.T, "secrets", "put-secret", scope, key, "--bytes-value", string(value))
stdout, stderr := RequireSuccessfulRun(t, "secrets", "put-secret", scope, key, "--bytes-value", string(value))
assert.Empty(t, stdout)
assert.Empty(t, stderr)

View File

@ -1,9 +1,5 @@
package testutil
import (
"testing"
)
type Cloud int
const (
@ -13,7 +9,7 @@ const (
)
// Implement [Requirement].
func (c Cloud) Verify(t *testing.T) {
func (c Cloud) Verify(t TestingT) {
if c != GetCloud(t) {
t.Skipf("Skipping %s-specific test", c)
}
@ -32,7 +28,7 @@ func (c Cloud) String() string {
}
}
func GetCloud(t *testing.T) Cloud {
func GetCloud(t TestingT) Cloud {
env := GetEnvOrSkipTest(t, "CLOUD_ENV")
switch env {
case "aws":
@ -50,6 +46,6 @@ func GetCloud(t *testing.T) Cloud {
return -1
}
func IsAWSCloud(t *testing.T) bool {
func IsAWSCloud(t TestingT) bool {
return GetCloud(t) == AWS
}

View File

@ -5,14 +5,13 @@ import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
// CopyDirectory copies the contents of a directory to another directory.
// The destination directory is created if it does not exist.
func CopyDirectory(t *testing.T, src, dst string) {
func CopyDirectory(t TestingT, src, dst string) {
err := filepath.WalkDir(src, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err

View File

@ -5,7 +5,6 @@ import (
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
@ -13,7 +12,7 @@ import (
// CleanupEnvironment sets up a pristine environment containing only $PATH and $HOME.
// The original environment is restored upon test completion.
// Note: use of this function is incompatible with parallel execution.
func CleanupEnvironment(t *testing.T) {
func CleanupEnvironment(t TestingT) {
// Restore environment when test finishes.
environ := os.Environ()
t.Cleanup(func() {
@ -41,7 +40,7 @@ func CleanupEnvironment(t *testing.T) {
// Changes into specified directory for the duration of the test.
// Returns the current working directory.
func Chdir(t *testing.T, dir string) string {
func Chdir(t TestingT, dir string) string {
// Prevent parallel execution when changing the working directory.
// t.Setenv automatically fails if t.Parallel is set.
t.Setenv("DO_NOT_RUN_IN_PARALLEL", "true")

View File

@ -3,12 +3,11 @@ package testutil
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TouchNotebook(t *testing.T, elems ...string) string {
func TouchNotebook(t TestingT, elems ...string) string {
path := filepath.Join(elems...)
err := os.MkdirAll(filepath.Dir(path), 0o755)
require.NoError(t, err)
@ -18,7 +17,7 @@ func TouchNotebook(t *testing.T, elems ...string) string {
return path
}
func Touch(t *testing.T, elems ...string) string {
func Touch(t TestingT, elems ...string) string {
path := filepath.Join(elems...)
err := os.MkdirAll(filepath.Dir(path), 0o755)
require.NoError(t, err)
@ -32,7 +31,7 @@ func Touch(t *testing.T, elems ...string) string {
}
// WriteFile writes content to a file.
func WriteFile(t *testing.T, path, content string) {
func WriteFile(t TestingT, path, content string) {
err := os.MkdirAll(filepath.Dir(path), 0o755)
require.NoError(t, err)
@ -47,7 +46,7 @@ func WriteFile(t *testing.T, path, content string) {
}
// ReadFile reads a file and returns its content as a string.
func ReadFile(t require.TestingT, path string) string {
func ReadFile(t TestingT, path string) string {
b, err := os.ReadFile(path)
require.NoError(t, err)

View File

@ -5,11 +5,10 @@ import (
"math/rand"
"os"
"strings"
"testing"
)
// GetEnvOrSkipTest proceeds with test only with that env variable.
func GetEnvOrSkipTest(t *testing.T, name string) string {
func GetEnvOrSkipTest(t TestingT, name string) string {
value := os.Getenv(name)
if value == "" {
t.Skipf("Environment variable %s is missing", name)

View File

@ -0,0 +1,27 @@
package testutil
// TestingT is an interface wrapper around *testing.T that provides the methods
// that are used by the test package to convey information about test failures.
//
// We use an interface so we can wrap *testing.T and provide additional functionality.
type TestingT interface {
Log(args ...any)
Logf(format string, args ...any)
Error(args ...any)
Errorf(format string, args ...any)
Fatal(args ...any)
Fatalf(format string, args ...any)
Skip(args ...any)
Skipf(format string, args ...any)
FailNow()
Cleanup(func())
Setenv(key, value string)
TempDir() string
}

View File

@ -5,12 +5,11 @@ import (
"context"
"os/exec"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func RequireJDK(t *testing.T, ctx context.Context, version string) {
func RequireJDK(t TestingT, ctx context.Context, version string) {
var stderr bytes.Buffer
cmd := exec.Command("javac", "-version")

View File

@ -1,18 +1,14 @@
package testutil
import (
"testing"
)
// Requirement is the interface for test requirements.
type Requirement interface {
Verify(t *testing.T)
Verify(t TestingT)
}
// Require should be called at the beginning of a test to ensure that all
// requirements are met before running the test.
// If any requirement is not met, the test will be skipped.
func Require(t *testing.T, requirements ...Requirement) {
func Require(t TestingT, requirements ...Requirement) {
for _, r := range requirements {
r.Verify(t)
}

View File

@ -0,0 +1,36 @@
package testutil_test
import (
"go/parser"
"go/token"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNoTestingImport checks that no file in the package imports the testing package.
// All exported functions must use the TestingT interface instead of *testing.T.
func TestNoTestingImport(t *testing.T) {
// Parse the package
fset := token.NewFileSet()
pkgs, err := parser.ParseDir(fset, ".", nil, parser.AllErrors)
require.NoError(t, err)
// Iterate through the files in the package
for _, pkg := range pkgs {
for _, file := range pkg.Files {
// Skip test files
if strings.HasSuffix(fset.Position(file.Pos()).Filename, "_test.go") {
continue
}
// Check the imports of each file
for _, imp := range file.Imports {
if imp.Path.Value == `"testing"` {
assert.Fail(t, "File imports the testing package", "File %s imports the testing package", fset.Position(file.Pos()).Filename)
}
}
}
}
}