Skip to content

Commit

Permalink
Use a min heap to reduce the time complexity in the process of findin…
Browse files Browse the repository at this point in the history
…g the minimum while adding a new item.
  • Loading branch information
xiaoqingwanga committed Sep 18, 2024
1 parent 68c33eb commit 878899f
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 46 deletions.
105 changes: 105 additions & 0 deletions min_heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2020 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ristretto

// Item interface for heap elements
type Comparable[T any] interface {
Less(other *T) bool
}

// MinHeap represents a min heap data structure
type MinHeap[T Comparable[T]] struct {
items []*T
}

// NewMinHeap creates a new min heap
func NewMinHeap[T Comparable[T]]() *MinHeap[T] {
return &MinHeap[T]{}
}

// Insert adds a new element to the heap
func (h *MinHeap[T]) Insert(item *T) {
h.items = append(h.items, item)
h.heapifyUp(len(h.items) - 1)
}

// Extract removes and returns the minimum element from the heap
func (h *MinHeap[T]) Extract() (*T, bool) {
if len(h.items) == 0 {
return nil, false
}

min := h.items[0]
last := len(h.items) - 1
h.items[0] = h.items[last]
h.items = h.items[:last]

if len(h.items) > 0 {
h.heapifyDown(0)
}

return min, true
}

// heapifyUp maintains the heap property by moving a node up
func (h *MinHeap[T]) heapifyUp(index int) {
for index > 0 {
parentIndex := (index - 1) / 2
if !(*h.items[index]).Less(h.items[parentIndex]) {
break
}
h.items[parentIndex], h.items[index] = h.items[index], h.items[parentIndex]
index = parentIndex
}
}

// heapifyDown maintains the heap property by moving a node down
func (h *MinHeap[T]) heapifyDown(index int) {
for {
smallest := index
leftChild := 2*index + 1
rightChild := 2*index + 2

if leftChild < len(h.items) && (*h.items[leftChild]).Less(h.items[smallest]) {
smallest = leftChild
}

if rightChild < len(h.items) && (*h.items[rightChild]).Less(h.items[smallest]) {
smallest = rightChild
}

if smallest == index {
break
}

h.items[index], h.items[smallest] = h.items[smallest], h.items[index]
index = smallest
}
}

// Peek returns the minimum element without removing it
func (h *MinHeap[T]) Peek() (*T, bool) {
if len(h.items) == 0 {
return nil, false
}
return h.items[0], true
}

// Size returns the number of elements in the heap
func (h *MinHeap[T]) Size() int {
return len(h.items)
}
44 changes: 44 additions & 0 deletions min_heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ristretto

import (
"testing"

"github.com/stretchr/testify/require"
)

type CacheItem struct {
Key uint64
Hits uint64
}

func (p CacheItem) Less(other *CacheItem) bool {
return p.Hits < other.Hits
}

func TestMinHeap(t *testing.T) {
heap := NewMinHeap[CacheItem]()

// Test insertion
heap.Insert(&CacheItem{100, 30})
heap.Insert(&CacheItem{200, 25})

peek, _ := heap.Peek()
require.Equal(t, uint64(25), peek.Hits, "Peek returned incorrect item")

heap.Insert(&CacheItem{300, 35})
heap.Insert(&CacheItem{400, 20})

require.Equalf(t, 4, heap.Size(), "Expected heap size 4, got %d", heap.Size())

// Test extraction
expectedHits := []uint64{20, 25, 30, 35}
for i, expectedHit := range expectedHits {
item, ok := heap.Extract()
require.Truef(t, ok, "Failed to extract item %d", i)
require.Equalf(t, expectedHit, item.Hits, "Expected hit %d, got %d", expectedHit, item.Hits)
}

// Test empty heap
_, ok := heap.Extract()
require.False(t, ok, "Expected false when extracting from empty heap")
}
61 changes: 30 additions & 31 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package ristretto

import (
"math"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -94,6 +93,11 @@ func (p *defaultPolicy[V]) CollectMetrics(metrics *Metrics) {
type policyPair struct {
key uint64
cost int64
hits int64
}

func (p policyPair) Less(other *policyPair) bool {
return p.hits <= other.hits
}

func (p *defaultPolicy[V]) processItems() {
Expand Down Expand Up @@ -160,45 +164,33 @@ func (p *defaultPolicy[V]) Add(key uint64, cost int64) ([]*Item[V], bool) {
// incHits is the hit count for the incoming item.
incHits := p.admit.Estimate(key)
// sample is the eviction candidate pool to be filled via random sampling.
// TODO: perhaps we should use a min heap here. Right now our time
// complexity is N for finding the min. Min heap should bring it down to
// O(lg N).
sample := make([]*policyPair, 0, lfuSample)
sample := NewMinHeap[policyPair]()
// As items are evicted they will be appended to victims.
victims := make([]*Item[V], 0)

// Delete victims until there's enough space or a minKey is found that has
// more hits than incoming item.
for ; room < 0; room = p.evict.roomLeft(cost) {
// Fill up empty slots in sample.
sample = p.evict.fillSample(sample)

// Find minimally used item in sample.
minKey, minHits, minId, minCost := uint64(0), int64(math.MaxInt64), 0, int64(0)
for i, pair := range sample {
// Look up hit count for sample key.
if hits := p.admit.Estimate(pair.key); hits < minHits {
minKey, minHits, minId, minCost = pair.key, hits, i, pair.cost
}
// Get samples for empty slots.
slots := p.evict.fetchTopUpSamples(sample.Size())
for _, slot := range slots {
slot.hits = p.admit.Estimate(slot.key)
sample.Insert(slot)
}

// If the incoming item isn't worth keeping in the policy, reject.
if incHits < minHits {
if peek, _ := sample.Peek(); incHits < peek.hits {
p.metrics.add(rejectSets, key, 1)
return victims, false
}

peek, _ := sample.Extract()
// Delete the victim from metadata.
p.evict.del(minKey)

// Delete the victim from sample.
sample[minId] = sample[len(sample)-1]
sample = sample[:len(sample)-1]
// Store victim in evicted victims slice.
p.evict.del(peek.key)
victims = append(victims, &Item[V]{
Key: minKey,
Key: peek.key,
Conflict: 0,
Cost: minCost,
Cost: peek.cost,
})
}

Expand Down Expand Up @@ -309,17 +301,24 @@ func (p *sampledLFU) roomLeft(cost int64) int64 {
return p.getMaxCost() - (p.used + cost)
}

func (p *sampledLFU) fillSample(in []*policyPair) []*policyPair {
if len(in) >= lfuSample {
return in
/**
* Fetches samples for empty slots when the upper limit quantity is lfuSample.
*/
func (p *sampledLFU) fetchTopUpSamples(size int) []*policyPair {
if size >= lfuSample {
return []*policyPair{}
}

gap := lfuSample - size
samples := make([]*policyPair, 0)

for key, cost := range p.keyCosts {
in = append(in, &policyPair{key, cost})
if len(in) >= lfuSample {
return in
samples = append(samples, &policyPair{key: key, cost: cost})
if len(samples) >= gap {
return samples
}
}
return in
return samples
}

func (p *sampledLFU) del(key uint64) {
Expand Down
26 changes: 11 additions & 15 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,24 +199,20 @@ func TestSampledLFURoom(t *testing.T) {
require.Equal(t, int64(6), e.roomLeft(4))
}

func TestSampledLFUSample(t *testing.T) {
func TestSampledLFUGet(t *testing.T) {
e := newSampledLFU(16)
e.add(4, 4)
e.add(5, 5)
sample := e.fillSample([]*policyPair{
{1, 1},
{2, 2},
{3, 3},
})
k := sample[len(sample)-1].key
require.Equal(t, 5, len(sample))
require.NotEqual(t, 1, k)
require.NotEqual(t, 2, k)
require.NotEqual(t, 3, k)
require.Equal(t, len(sample), len(e.fillSample(sample)))
e.del(5)
sample = e.fillSample(sample[:len(sample)-2])
require.Equal(t, 4, len(sample))
samples := e.fetchTopUpSamples(3)
require.Equal(t, 2, len(samples))
samples = e.fetchTopUpSamples(4)
require.Equal(t, 1, len(samples))
e.add(6, 6)
samples = e.fetchTopUpSamples(2)
require.Equal(t, 3, len(samples))
require.NotEqual(t, samples[1].key, samples[2].key)
require.NotEqual(t, samples[1].key, samples[0].key)
require.NotEqual(t, samples[0].key, samples[2].key)
}

func TestTinyLFUIncrement(t *testing.T) {
Expand Down

0 comments on commit 878899f

Please sign in to comment.