Skip to content

Commit

Permalink
Refactor interface for methods to return errors
Browse files Browse the repository at this point in the history
This commit make the `Store` methods more adapatable by returning
errors.
It also refactor `Resize` in order to do what the name implies.
Finally we change the custom error type to avoid returning objects that
can contain sensitive data.

Signed-off-by: Soule BA <[email protected]>
  • Loading branch information
souleb committed Jun 14, 2024
1 parent f133773 commit 61b728f
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 69 deletions.
67 changes: 36 additions & 31 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (c *Cache[T]) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return ErrClosed
return &Error{ErrCodeClosed, closeMsg}
}
c.janitor.stop <- true
c.closed = true
Expand All @@ -136,14 +136,14 @@ func (c *Cache[T]) Set(object T) error {
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return KeyError{object, err}
return &Error{ErrCodeInvalidArgument, err.Error()}
}

c.mu.Lock()
if c.closed {
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return KeyError{object, ErrClosed}
return &Error{ErrCodeClosed, closeMsg}
}
_, found := c.index[key]
if found {
Expand All @@ -162,7 +162,7 @@ func (c *Cache[T]) Set(object T) error {
}
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return KeyError{object, ErrFull}
return &Error{ErrCodeFull, fullMsg}
}

func (c *cache[T]) set(key string, object T) {
Expand Down Expand Up @@ -190,17 +190,17 @@ func (c *Cache[T]) Get(object T) (item T, exists bool, err error) {
lvs, err = c.labelsFunc(object, len(c.metrics.getExtraLabels()))
if err != nil {
recordRequest(c.metrics, StatusFailure)
return res, false, KeyError{object, err}
return res, false, &Error{ErrCodeInvalidArgument, err.Error()}
}
}
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return res, false, KeyError{object, err}
return res, false, &Error{ErrCodeInvalidArgument, err.Error()}
}
item, found, err := c.get(key)
if err != nil {
return res, false, KeyError{object, err}
return res, false, err
}
if !found {
recordEvent(c.metrics, CacheEventTypeMiss, lvs...)
Expand Down Expand Up @@ -232,7 +232,7 @@ func (c *cache[T]) get(key string) (T, bool, error) {
if c.closed {
c.mu.RUnlock()
recordRequest(c.metrics, StatusFailure)
return res, false, ErrClosed
return res, false, &Error{ErrCodeClosed, closeMsg}
}
item, found := c.index[key]
if !found {
Expand All @@ -259,13 +259,13 @@ func (c *Cache[T]) Delete(object T) error {
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return KeyError{object, err}
return &Error{ErrCodeInvalidArgument, err.Error()}
}
c.mu.Lock()
if c.closed {
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return KeyError{object, ErrClosed}
return &Error{ErrCodeClosed, closeMsg}
}
if item, ok := c.index[key]; ok {
// set the item expiration to now
Expand All @@ -292,40 +292,45 @@ func (c *cache[T]) Clear() {
}

// ListKeys returns a slice of the keys in the cache.
// If the cache is closed, ListKeys returns nil.
func (c *cache[T]) ListKeys() []string {
func (c *cache[T]) ListKeys() ([]string, error) {
c.mu.RLock()
if c.closed {
c.mu.RUnlock()
recordRequest(c.metrics, StatusFailure)
return nil
return nil, &Error{ErrCodeClosed, closeMsg}
}
keys := make([]string, 0, len(c.index))
for k := range c.index {
keys = append(keys, k)
}
c.mu.RUnlock()
recordRequest(c.metrics, StatusSuccess)
return keys
return keys, nil
}

// Resize resizes the cache and returns the number of index removed.
// Size must be greater than zero.
func (c *cache[T]) Resize(size int) int {
func (c *cache[T]) Resize(size int) (int, error) {
if size <= 0 {
recordRequest(c.metrics, StatusFailure)
return 0
}
overflow := len(c.items) - size
if overflow <= 0 {
recordRequest(c.metrics, StatusSuccess)
return 0
return 0, &Error{ErrCodeInvalidArgument, "size must be greater than zero"}
}

c.mu.Lock()
overflow := len(c.items) - size
if c.closed {
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return 0
return 0, &Error{ErrCodeClosed, closeMsg}
}

// set the new capacity
c.capacity = size

if overflow <= 0 {
c.mu.Unlock()
recordRequest(c.metrics, StatusSuccess)
return 0, nil
}

if !c.sorted {
Expand All @@ -346,22 +351,22 @@ func (c *cache[T]) Resize(size int) int {
c.items = c.items[overflow:]
c.mu.Unlock()
recordRequest(c.metrics, StatusSuccess)
return overflow
return overflow, nil
}

// HasExpired returns true if the item has expired.
func (c *Cache[T]) HasExpired(object T) (bool, error) {
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return false, KeyError{object, err}
return false, &Error{ErrCodeInvalidArgument, err.Error()}
}

c.mu.RLock()
if c.closed {
c.mu.RUnlock()
recordRequest(c.metrics, StatusFailure)
return false, KeyError{object, ErrClosed}
return false, &Error{ErrCodeClosed, closeMsg}
}
item, ok := c.index[key]
if !ok {
Expand All @@ -386,20 +391,20 @@ func (c *Cache[T]) SetExpiration(object T, expiration time.Time) error {
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return KeyError{object, err}
return &Error{ErrCodeInvalidArgument, err.Error()}
}

c.mu.Lock()
if c.closed {
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return KeyError{object, ErrClosed}
return &Error{ErrCodeClosed, closeMsg}
}
item, ok := c.index[key]
if !ok {
c.mu.Unlock()
recordRequest(c.metrics, StatusFailure)
return KeyError{object, ErrNotFound}
return &Error{ErrCodeNotFound, notFoundMsg}
}
item.expiresAt = expiration
// mark the items as not sorted
Expand All @@ -416,19 +421,19 @@ func (c *Cache[T]) GetExpiration(object T) (time.Time, error) {
key, err := c.keyFunc(object)
if err != nil {
recordRequest(c.metrics, StatusFailure)
return time.Time{}, KeyError{object, err}
return time.Time{}, &Error{ErrCodeInvalidArgument, err.Error()}
}
c.mu.RLock()
if c.closed {
c.mu.RUnlock()
recordRequest(c.metrics, StatusFailure)
return time.Time{}, KeyError{object, ErrClosed}
return time.Time{}, &Error{ErrCodeClosed, closeMsg}
}
item, ok := c.index[key]
if !ok {
c.mu.RUnlock()
recordRequest(c.metrics, StatusSuccess)
return time.Time{}, KeyError{object, ErrNotFound}
return time.Time{}, &Error{ErrCodeNotFound, notFoundMsg}
}
if !item.expiresAt.IsZero() {
if item.expiresAt.Compare(time.Now()) < 0 {
Expand Down
17 changes: 14 additions & 3 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ func Test_Cache_deleteExpired(t *testing.T) {
}

time.Sleep(5 * time.Millisecond)
keys := cache.ListKeys()
keys, err := cache.ListKeys()
g.Expect(err).ToNot(HaveOccurred())
g.Expect(keys).To(ConsistOf(tt.nonExpiredKeys))
})
}
Expand Down Expand Up @@ -499,9 +500,17 @@ func Test_Cache_Resize(t *testing.T) {
g.Expect(err).ToNot(HaveOccurred())
}

deleted := cache.Resize(10)
deleted, err := cache.Resize(10)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(deleted).To(Equal(n - 10))
g.Expect(cache.ListKeys()).To(HaveLen(10))
g.Expect(cache.capacity).To(Equal(10))

deleted, err = cache.Resize(15)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(deleted).To(Equal(0))
g.Expect(cache.ListKeys()).To(HaveLen(10))
g.Expect(cache.capacity).To(Equal(15))
}

func TestCache_Concurrent(t *testing.T) {
Expand Down Expand Up @@ -533,12 +542,14 @@ func TestCache_Concurrent(t *testing.T) {
defer wg.Done()
<-run
_, _, _ = cache.Get(objmap[key])
_ = cache.SetExpiration(objmap[key], time.Now().Add(noExpiration))
}()
}
close(run)
wg.Wait()

keys := cache.ListKeys()
keys, err := cache.ListKeys()
g.Expect(err).ToNot(HaveOccurred())
g.Expect(len(keys)).To(Equal(len(objmap)))

for _, obj := range objmap {
Expand Down
65 changes: 47 additions & 18 deletions cache/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,58 @@ limitations under the License.

package cache

import (
"errors"
"fmt"
const (
closeMsg = "cache is closed"
fullMsg = "cache is full"
notFoundMsg = "object not found"
)

var (
// ErrNotFound is returned when an item is not found in the cache.
ErrNotFound = errors.New("not found")
ErrAlreadyExists = errors.New("already exists")
ErrClosed = errors.New("cache closed")
ErrFull = errors.New("cache full")
ErrExpired = errors.New("key has expired")
ErrNoRegisterer = errors.New("no prometheus registerer provided")
// ErrCode is a code that represents the type of error.
type ErrCode uint

const (
ErrCodeUnknown = iota + 1 // 0 is reserved
ErrCodeNotFound
ErrCodeAlreadyExists
ErrCodeClosed
ErrCodeExpired
ErrCodeFull
ErrCodeInvalidArgument
)

// KeyError will be returned any time a KeyFunc gives an error; it includes the object
// at fault.
type KeyError struct {
Value any
Err error
// Error will be returned for cache errors.
type Error struct {
StatusCode ErrCode
msg string
}

// Error gives a human-readable description of the error.
func (k KeyError) Error() string {
return fmt.Sprintf("couldn't create key for value %+v: %v", k.Value, k.Err)
func (e *Error) Error() string {
return e.msg
}

// Is compares the error with the target error.
// It returns true if the status code of the errors match.
func (e *Error) Is(target error) bool {
if err, ok := target.(*Error); ok {
return e.StatusCode == err.StatusCode
}
return false
}

// IsErrCode verifies if the error is of the given code.
func IsErrCode(err error, code ErrCode) bool {
if e, ok := err.(*Error); ok {
return e.StatusCode == code
}
return false
}

// StatusCodeFromError returns the status code of the error.
// If the error is not of type Error, it returns Unknown.
func StatusCodeFromError(err error) ErrCode {
if e, ok := err.(*Error); ok {
return e.StatusCode
}
return ErrCodeUnknown
}
Loading

0 comments on commit 61b728f

Please sign in to comment.