diff --git a/concurrent_swiss_map.go b/concurrent_swiss_map.go index ac87e23..6bf7553 100644 --- a/concurrent_swiss_map.go +++ b/concurrent_swiss_map.go @@ -79,6 +79,18 @@ func (m *CsMap[K, V]) Delete(key K) bool { return shard.items.DeleteWithHash(key, hashShardPair.hash) } +func (m *CsMap[K, V]) DeleteIf(key K, condition func(value V) bool) bool { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + defer shard.Unlock() + value, ok := shard.items.GetWithHash(key, hashShardPair.hash) + if ok && condition(value) { + return shard.items.DeleteWithHash(key, hashShardPair.hash) + } + return false +} + func (m *CsMap[K, V]) Load(key K) (V, bool) { hashShardPair := m.getShard(key) shard := hashShardPair.shard @@ -127,6 +139,18 @@ func (m *CsMap[K, V]) SetIfAbsent(key K, value V) { } } +func (m *CsMap[K, V]) SetIf(key K, conditionFn func(previousVale V, previousFound bool) (value V, set bool)) { + hashShardPair := m.getShard(key) + shard := hashShardPair.shard + shard.Lock() + defer shard.Unlock() + value, found := shard.items.GetWithHash(key, hashShardPair.hash) + value, ok := conditionFn(value, found) + if ok { + shard.items.PutWithHash(key, value, hashShardPair.hash) + } +} + func (m *CsMap[K, V]) SetIfPresent(key K, value V) { hashShardPair := m.getShard(key) shard := hashShardPair.shard diff --git a/concurrent_swiss_map_test.go b/concurrent_swiss_map_test.go index 21bd331..795ef99 100644 --- a/concurrent_swiss_map_test.go +++ b/concurrent_swiss_map_test.go @@ -52,7 +52,6 @@ func TestSetIfAbsent(t *testing.T) { t.Fatal("1 should be exist") } } - func TestSetIfPresent(t *testing.T) { myMap := csmap.Create[int, string]() myMap.SetIfPresent(1, "test") @@ -68,6 +67,65 @@ func TestSetIfPresent(t *testing.T) { } } +func TestSetIf(t *testing.T) { + myMap := csmap.Create[int, string]() + myMap.SetIf(1, func(previousVale string, previousFound bool) (value string, set bool) { + // operate like a SetIfAbsent... + if !previousFound { + return "test", true + } + return "", false + }) + value, _ := myMap.Load(1) + if value != "test" { + t.Fatal("value should test") + } + + myMap.SetIf(1, func(previousVale string, previousFound bool) (value string, set bool) { + // operate like a SetIfAbsent... + if !previousFound { + return "bad", true + } + return "", false + }) + value, _ = myMap.Load(1) + if value != "test" { + t.Fatal("value should test") + } +} + +func TestDeleteIf(t *testing.T) { + myMap := csmap.Create[int, string]() + myMap.Store(1, "test") + ok1 := myMap.DeleteIf(20, func(value string) bool { + t.Fatal("condition function should not have been called") + return false + }) + if ok1 { + t.Fatal("ok1 should be false") + } + + ok2 := myMap.DeleteIf(1, func(value string) bool { + if value != "test" { + t.Fatal("condition function arg should be tests") + } + return false // don't delete + }) + if ok2 { + t.Fatal("ok1 should be false") + } + + ok3 := myMap.DeleteIf(1, func(value string) bool { + if value != "test" { + t.Fatal("condition function arg should be tests") + } + return true // allow the delete + }) + if !ok3 { + t.Fatal("ok2 should be true") + } +} + func TestCount(t *testing.T) { myMap := csmap.Create[int, string]() myMap.SetIfAbsent(1, "test")