diff --git a/libs/auth/cache/cache.go b/libs/auth/cache/cache.go index 097353e7..7ca9e8b6 100644 --- a/libs/auth/cache/cache.go +++ b/libs/auth/cache/cache.go @@ -9,6 +9,7 @@ import ( type TokenCache interface { Store(key string, t *oauth2.Token) error Lookup(key string) (*oauth2.Token, error) + DeleteKey(key string) error } var tokenCache int diff --git a/libs/auth/cache/file.go b/libs/auth/cache/file.go index 38dfea9f..9e99070e 100644 --- a/libs/auth/cache/file.go +++ b/libs/auth/cache/file.go @@ -73,6 +73,29 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { return t, nil } +func (c *FileTokenCache) DeleteKey(key string) error { + err := c.load() + if errors.Is(err, fs.ErrNotExist) { + return ErrNotConfigured + } else if err != nil { + return fmt.Errorf("load: %w", err) + } + c.Version = tokenCacheVersion + if c.Tokens == nil { + c.Tokens = map[string]*oauth2.Token{} + } + _, ok := c.Tokens[key] + if !ok { + return ErrNotConfigured + } + delete(c.Tokens, key) + raw, err := json.MarshalIndent(c, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + return os.WriteFile(c.fileLocation, raw, ownerReadWrite) +} + func (c *FileTokenCache) location() (string, error) { home, err := os.UserHomeDir() if err != nil { diff --git a/libs/auth/cache/file_test.go b/libs/auth/cache/file_test.go index 3e4aae36..3d0801a5 100644 --- a/libs/auth/cache/file_test.go +++ b/libs/auth/cache/file_test.go @@ -103,3 +103,29 @@ func TestStoreOnDev(t *testing.T) { // macOS: read-only file system assert.Error(t, err) } + +func TestStoreAndDeleteKey(t *testing.T) { + setup(t) + c := &FileTokenCache{} + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + l := &FileTokenCache{} + err = l.DeleteKey("x") + require.NoError(t, err) + assert.Equal(t, 1, len(l.Tokens)) + + _, err = l.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) + + tok, err := l.Lookup("y") + require.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) +} diff --git a/libs/auth/cache/in_memory.go b/libs/auth/cache/in_memory.go index 469d4557..6daaf868 100644 --- a/libs/auth/cache/in_memory.go +++ b/libs/auth/cache/in_memory.go @@ -23,4 +23,14 @@ func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { return nil } +// DeleteKey implements TokenCache. +func (i *InMemoryTokenCache) DeleteKey(key string) error { + _, ok := i.Tokens[key] + if !ok { + return ErrNotConfigured + } + delete(i.Tokens, key) + return nil +} + var _ TokenCache = (*InMemoryTokenCache)(nil) diff --git a/libs/auth/cache/in_memory_test.go b/libs/auth/cache/in_memory_test.go index d8394d3b..f67bf8d3 100644 --- a/libs/auth/cache/in_memory_test.go +++ b/libs/auth/cache/in_memory_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) @@ -42,3 +43,29 @@ func TestInMemoryCacheStore(t *testing.T) { assert.Equal(t, res, token) assert.NoError(t, err) } + +func TestInMemoryDeleteKey(t *testing.T) { + c := &InMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{}, + } + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + err = c.DeleteKey("x") + require.NoError(t, err) + assert.Equal(t, 1, len(c.Tokens)) + + _, err = c.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) + + tok, err := c.Lookup("y") + require.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) +} diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go index ea6a8061..42cc0ef9 100644 --- a/libs/auth/oauth_test.go +++ b/libs/auth/oauth_test.go @@ -53,8 +53,9 @@ func TestOidcForWorkspace(t *testing.T) { } type tokenCacheMock struct { - store func(key string, t *oauth2.Token) error - lookup func(key string) (*oauth2.Token, error) + store func(key string, t *oauth2.Token) error + lookup func(key string) (*oauth2.Token, error) + deleteKey func(key string) error } func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { @@ -71,6 +72,13 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { return m.lookup(key) } +func (m *tokenCacheMock) DeleteKey(key string) error { + if m.deleteKey == nil { + panic("no deleteKey mock") + } + return m.deleteKey(key) +} + func TestLoad(t *testing.T) { p := &PersistentAuth{ Host: "abc",