Use local Terraform state only when lineage matches (#1588)

## Changes
DABs deployments should be isolated from each other if their `root_path`
or workspace host differs. This PR fixes a bug where the local Terraform
state is reused (piggybacked) when the same cwd is used to deploy two
isolated deployments for the same bundle target. This can happen if:
1. A user switches to a different identity on the same machine. 
2. The workspace host URL the bundle/target points to is changed.
3. A user changes the `root_path` while doing bundle development.

To solve this problem, we rely on the `lineage` field in the Terraform
state, which is a UUID that uniquely identifies a Terraform deployment.
There is a 1:1 mapping between a Terraform deployment and a bundle
deployment.

For more details on how lineage works in terraform, see:
https://developer.hashicorp.com/terraform/language/state/backends#manual-state-pull-push

## Tests
Manually verified that changing the identity no longer results in the
incorrect terraform state being used. Also, new unit tests are added.
This commit is contained in:
shreyas-goenka 2024-07-18 15:17:59 +05:30 committed by GitHub
parent af0114a5a6
commit 5b65358146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 194 additions and 188 deletions

View File

@ -1,8 +1,8 @@
package terraform package terraform
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"io" "io"
"io/fs" "io/fs"
@ -12,10 +12,14 @@ import (
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/deploy" "github.com/databricks/cli/bundle/deploy"
"github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/log" "github.com/databricks/cli/libs/log"
) )
// tfState holds the two fields of a Terraform state file that the state-pull
// logic decodes: the serial number and the lineage identifier. Any other keys
// in the JSON document are ignored when unmarshalling into this struct.
type tfState struct {
// Serial is decoded from the "serial" key; remote and local values are
// compared to decide which state is newer.
Serial int64 `json:"serial"`
// Lineage is decoded from the "lineage" key; remote and local values are
// compared to decide whether both files belong to the same deployment.
Lineage string `json:"lineage"`
}
type statePull struct { type statePull struct {
filerFactory deploy.FilerFactory filerFactory deploy.FilerFactory
} }
@ -24,74 +28,105 @@ func (l *statePull) Name() string {
return "terraform:state-pull" return "terraform:state-pull"
} }
func (l *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buffer, error) { func (l *statePull) remoteState(ctx context.Context, b *bundle.Bundle) (*tfState, []byte, error) {
// Download state file from filer to local cache directory. f, err := l.filerFactory(b)
remote, err := f.Read(ctx, TerraformStateFileName)
if err != nil { if err != nil {
// On first deploy this state file doesn't yet exist. return nil, nil, err
if errors.Is(err, fs.ErrNotExist) {
return nil, nil
}
return nil, err
} }
defer remote.Close() r, err := f.Read(ctx, TerraformStateFileName)
if err != nil {
return nil, nil, err
}
defer r.Close()
var buf bytes.Buffer content, err := io.ReadAll(r)
_, err = io.Copy(&buf, remote) if err != nil {
return nil, nil, err
}
state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return nil, nil, err
}
return state, content, nil
}
func (l *statePull) localState(ctx context.Context, b *bundle.Bundle) (*tfState, error) {
dir, err := Dir(ctx, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &buf, nil content, err := os.ReadFile(filepath.Join(dir, TerraformStateFileName))
if err != nil {
return nil, err
}
state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return nil, err
}
return state, nil
} }
func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
f, err := l.filerFactory(b)
if err != nil {
return diag.FromErr(err)
}
dir, err := Dir(ctx, b) dir, err := Dir(ctx, b)
if err != nil { if err != nil {
return diag.FromErr(err) return diag.FromErr(err)
} }
// Download state file from filer to local cache directory. localStatePath := filepath.Join(dir, TerraformStateFileName)
log.Infof(ctx, "Opening remote state file")
remote, err := l.remoteState(ctx, f) // Case: Remote state file does not exist. In this case we fallback to using the
if err != nil { // local Terraform state. This allows users to change the "root_path" their bundle is
log.Infof(ctx, "Unable to open remote state file: %s", err) // configured with.
return diag.FromErr(err) remoteState, remoteContent, err := l.remoteState(ctx, b)
} if errors.Is(err, fs.ErrNotExist) {
if remote == nil { log.Infof(ctx, "Remote state file does not exist. Using local Terraform state.")
log.Infof(ctx, "Remote state file does not exist")
return nil return nil
} }
// Expect the state file to live under dir.
local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600)
if err != nil { if err != nil {
return diag.Errorf("failed to read remote state file: %v", err)
}
// Expected invariant: remote state file should have a lineage UUID. Error
// if that's not the case.
if remoteState.Lineage == "" {
return diag.Errorf("remote state file does not have a lineage")
}
// Case: Local state file does not exist. In this case we should rely on the remote state file.
localState, err := l.localState(ctx, b)
if errors.Is(err, fs.ErrNotExist) {
log.Infof(ctx, "Local state file does not exist. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err) return diag.FromErr(err)
} }
defer local.Close() if err != nil {
return diag.Errorf("failed to read local state file: %v", err)
if !IsLocalStateStale(local, bytes.NewReader(remote.Bytes())) {
log.Infof(ctx, "Local state is the same or newer, ignoring remote state")
return nil
} }
// Truncating the file before writing // If the lineage does not match, the Terraform state files do not correspond to the same deployment.
local.Truncate(0) if localState.Lineage != remoteState.Lineage {
local.Seek(0, 0) log.Infof(ctx, "Remote and local state lineages do not match. Using remote Terraform state. Invalidating local Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
// Write file to disk.
log.Infof(ctx, "Writing remote state file to local cache directory")
_, err = io.Copy(local, bytes.NewReader(remote.Bytes()))
if err != nil {
return diag.FromErr(err) return diag.FromErr(err)
} }
// If the remote state is newer than the local state, we should use the remote state.
if remoteState.Serial > localState.Serial {
log.Infof(ctx, "Remote state is newer than local state. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}
// default: local state is newer or equal to remote state in terms of serial sequence.
// It is also of the same lineage. Keep using the local state.
return nil return nil
} }

View File

@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
func mockStateFilerForPull(t *testing.T, contents map[string]int, merr error) filer.Filer { func mockStateFilerForPull(t *testing.T, contents map[string]any, merr error) filer.Filer {
buf, err := json.Marshal(contents) buf, err := json.Marshal(contents)
assert.NoError(t, err) assert.NoError(t, err)
@ -41,86 +41,123 @@ func statePullTestBundle(t *testing.T) *bundle.Bundle {
} }
} }
func TestStatePullLocalMissingRemoteMissing(t *testing.T) { func TestStatePullLocalErrorWhenRemoteHasNoLineage(t *testing.T) {
m := &statePull{ m := &statePull{}
identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist)),
} t.Run("no local state", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))
ctx := context.Background() ctx := context.Background()
b := statePullTestBundle(t) b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})
t.Run("local state with lineage", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))
ctx := context.Background()
b := statePullTestBundle(t)
writeLocalState(t, ctx, b, map[string]any{"serial": 5, "lineage": "aaaa"})
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})
}
func TestStatePullLocal(t *testing.T) {
tcases := []struct {
name string
// remote state before applying the pull mutators
remote map[string]any
// local state before applying the pull mutators
local map[string]any
// expected local state after applying the pull mutators
expected map[string]any
}{
{
name: "remote missing, local missing",
remote: nil,
local: nil,
expected: nil,
},
{
name: "remote missing, local present",
remote: nil,
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// fallback to local state, since remote state is missing.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local stale",
remote: map[string]any{"serial": 10, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use remote, since remote is newer.
expected: map[string]any{"serial": float64(10), "lineage": "aaaa", "some_other_key": float64(123)},
},
{
name: "local equal",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use local state, since they are equal in terms of serial sequence.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local newer",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 6, "lineage": "aaaa"},
// use local state, since local is newer.
expected: map[string]any{"serial": float64(6), "lineage": "aaaa"},
},
{
name: "remote and local have different lineages",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10, "lineage": "bbbb"},
// use remote, since lineages do not match.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local is missing lineage",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10},
// use remote, since local does not have lineage.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
}
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
m := &statePull{}
if tc.remote == nil {
// nil represents no remote state file.
m.filerFactory = identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist))
} else {
m.filerFactory = identityFiler(mockStateFilerForPull(t, tc.remote, nil))
}
ctx := context.Background()
b := statePullTestBundle(t)
if tc.local != nil {
writeLocalState(t, ctx, b, tc.local)
}
diags := bundle.Apply(ctx, b, m) diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error()) assert.NoError(t, diags.Error())
// Confirm that no local state file has been written. if tc.expected == nil {
// nil represents no local state file is expected.
_, err := os.Stat(localStateFile(t, ctx, b)) _, err := os.Stat(localStateFile(t, ctx, b))
assert.ErrorIs(t, err, fs.ErrNotExist) assert.ErrorIs(t, err, fs.ErrNotExist)
} } else {
func TestStatePullLocalMissingRemotePresent(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
}
ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b) localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState) assert.Equal(t, tc.expected, localState)
} }
})
func TestStatePullLocalStale(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
} }
ctx := context.Background()
b := statePullTestBundle(t)
// Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}
func TestStatePullLocalEqual(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
}
ctx := context.Background()
b := statePullTestBundle(t)
// Write a local state file with the same serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 5})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}
func TestStatePullLocalNewer(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
}
ctx := context.Background()
b := statePullTestBundle(t)
// Write a local state file with a newer serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 6})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 6}, localState)
} }

View File

@ -55,7 +55,7 @@ func TestStatePush(t *testing.T) {
b := statePushTestBundle(t) b := statePushTestBundle(t)
// Write a stale local state file. // Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4}) writeLocalState(t, ctx, b, map[string]any{"serial": 4})
diags := bundle.Apply(ctx, b, m) diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error()) assert.NoError(t, diags.Error())
} }

View File

@ -26,19 +26,19 @@ func localStateFile(t *testing.T, ctx context.Context, b *bundle.Bundle) string
return filepath.Join(dir, TerraformStateFileName) return filepath.Join(dir, TerraformStateFileName)
} }
func readLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle) map[string]int { func readLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle) map[string]any {
f, err := os.Open(localStateFile(t, ctx, b)) f, err := os.Open(localStateFile(t, ctx, b))
require.NoError(t, err) require.NoError(t, err)
defer f.Close() defer f.Close()
var contents map[string]int var contents map[string]any
dec := json.NewDecoder(f) dec := json.NewDecoder(f)
err = dec.Decode(&contents) err = dec.Decode(&contents)
require.NoError(t, err) require.NoError(t, err)
return contents return contents
} }
func writeLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle, contents map[string]int) { func writeLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle, contents map[string]any) {
f, err := os.Create(localStateFile(t, ctx, b)) f, err := os.Create(localStateFile(t, ctx, b))
require.NoError(t, err) require.NoError(t, err)
defer f.Close() defer f.Close()

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"io"
"os" "os"
"path/filepath" "path/filepath"
@ -22,10 +21,6 @@ type resourcesState struct {
const SupportedStateVersion = 4 const SupportedStateVersion = 4
// serialState captures only the "serial" field of a JSON state document.
type serialState struct {
Serial int `json:"serial"`
}
type stateResource struct { type stateResource struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
@ -41,34 +36,6 @@ type stateInstanceAttributes struct {
ID string `json:"id"` ID string `json:"id"`
} }
// IsLocalStateStale reports whether the local state is older than the remote
// state, judged by their "serial" values.
//
// The two readers are loaded in order, and failures are handled asymmetrically:
//   - if the local state cannot be loaded, it is treated as stale (true);
//   - otherwise, if the remote state cannot be loaded, the local state is
//     treated as not stale (false);
//   - otherwise the result is true only when the local serial is strictly
//     less than the remote serial.
func IsLocalStateStale(local io.Reader, remote io.Reader) bool {
localState, err := loadState(local)
if err != nil {
// An unreadable/unparseable local state is considered stale.
return true
}
remoteState, err := loadState(remote)
if err != nil {
// An unreadable/unparseable remote state never supersedes the local one.
return false
}
return localState.Serial < remoteState.Serial
}
// loadState reads input to the end and decodes it as JSON into a serialState.
// It returns the read error or the JSON unmarshal error, whichever occurs
// first; on error the returned state is nil.
func loadState(input io.Reader) (*serialState, error) {
content, err := io.ReadAll(input)
if err != nil {
return nil, err
}
var s serialState
err = json.Unmarshal(content, &s)
if err != nil {
return nil, err
}
return &s, nil
}
func ParseResourcesState(ctx context.Context, b *bundle.Bundle) (*resourcesState, error) { func ParseResourcesState(ctx context.Context, b *bundle.Bundle) (*resourcesState, error) {
cacheDir, err := Dir(ctx, b) cacheDir, err := Dir(ctx, b)
if err != nil { if err != nil {

View File

@ -2,48 +2,15 @@ package terraform
import ( import (
"context" "context"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"testing/iotest"
"github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config" "github.com/databricks/cli/bundle/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// TestLocalStateIsNewer checks that a local serial (5) greater than the remote
// serial (4) is not reported as stale.
func TestLocalStateIsNewer(t *testing.T) {
local := strings.NewReader(`{"serial": 5}`)
remote := strings.NewReader(`{"serial": 4}`)
assert.False(t, IsLocalStateStale(local, remote))
}
// TestLocalStateIsOlder checks that a local serial (5) lower than the remote
// serial (6) is reported as stale.
func TestLocalStateIsOlder(t *testing.T) {
local := strings.NewReader(`{"serial": 5}`)
remote := strings.NewReader(`{"serial": 6}`)
assert.True(t, IsLocalStateStale(local, remote))
}
// TestLocalStateIsTheSame checks that equal local and remote serials (5) are
// not reported as stale.
func TestLocalStateIsTheSame(t *testing.T) {
local := strings.NewReader(`{"serial": 5}`)
remote := strings.NewReader(`{"serial": 5}`)
assert.False(t, IsLocalStateStale(local, remote))
}
// TestLocalStateMarkStaleWhenFailsToLoad checks that a local reader which
// returns an error causes the local state to be reported as stale.
func TestLocalStateMarkStaleWhenFailsToLoad(t *testing.T) {
local := iotest.ErrReader(fmt.Errorf("Random error"))
remote := strings.NewReader(`{"serial": 5}`)
assert.True(t, IsLocalStateStale(local, remote))
}
// TestLocalStateMarkNonStaleWhenRemoteFailsToLoad checks that a remote reader
// which returns an error causes the local state to be reported as not stale.
func TestLocalStateMarkNonStaleWhenRemoteFailsToLoad(t *testing.T) {
local := strings.NewReader(`{"serial": 5}`)
remote := iotest.ErrReader(fmt.Errorf("Random error"))
assert.False(t, IsLocalStateStale(local, remote))
}
func TestParseResourcesStateWithNoFile(t *testing.T) { func TestParseResourcesStateWithNoFile(t *testing.T) {
b := &bundle.Bundle{ b := &bundle.Bundle{
RootPath: t.TempDir(), RootPath: t.TempDir(),