Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements __hash__ for collections #81

Merged
merged 14 commits into from
Aug 6, 2024
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ exclude = [

[tool.ruff]
line-length = 79

[tool.ruff.lint]
select = ["ALL"]
ignore = [
"A001", # It's fine to shadow builtins
Expand Down Expand Up @@ -129,17 +131,18 @@ ignore = [
"TD", # These TODO style rules are also silly
"UP007", # We support 3.8 + 3.9
]

[tool.ruff.lint.flake8-pytest-style]
mark-parentheses = false

[tool.ruff.flake8-quotes]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"

[tool.ruff.lint.isort]
combine-as-imports = true
from-first = true

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"noxfile.py" = ["ANN", "D100", "S101", "T201"]
"docs/*" = ["ANN", "D", "INP001"]
"tests/*" = ["ANN", "B018", "D", "PLR", "RUF012", "S", "SIM", "TRY"]
125 changes: 116 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use pyo3::exceptions::PyIndexError;
use pyo3::exceptions::{PyIndexError, PyTypeError};
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyDict, PyIterator, PyTuple, PyType};
use pyo3::{exceptions::PyKeyError, types::PyMapping, types::PyTupleMethods};
use pyo3::{prelude::*, AsPyPointer, PyTypeInfo};
use rpds::{
HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync, Queue, QueueSync,
};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn hash_shuffle_bits(h: usize) -> usize {
((h ^ 89869747) ^ (h << 16)) * 3644798167
}

#[derive(Debug)]
struct Key {
Expand Down Expand Up @@ -178,6 +181,51 @@ impl HashTrieMapPy {
}
}

fn __hash__(&self, py: Python) -> PyResult<isize> {
// modified from https://github.com/python/cpython/blob/d69529d31ccd1510843cfac1ab53bb8cb027541f/Objects/setobject.c#L715

let mut hash_val = self
.inner
.iter()
.map(|(key, val)| {
let mut hasher = DefaultHasher::new();
let val_bound = val.bind(py);

let key_hash = key.hash;
let val_hash = val_bound.hash().map_err(|_| {
PyTypeError::new_err(format!(
"Unhashable type in HashTrieMap of key {}: {}",
key.inner
.bind(py)
.repr()
.and_then(|r| r.extract())
.unwrap_or("<repr> error".to_string()),
val_bound
.repr()
.and_then(|r| r.extract())
.unwrap_or("<repr> error".to_string())
))
})?;

hasher.write_isize(key_hash);
hasher.write_isize(val_hash);

Ok(hasher.finish() as usize)
})
.try_fold(0, |acc: usize, x: PyResult<usize>| {
PyResult::<usize>::Ok(acc ^ hash_shuffle_bits(x?))
})?;

// facto in the number of entries in the collection
hash_val ^= (self.inner.size() + 1) * 1927868237;

// dispense patterns in the hash value
hash_val ^= (hash_val >> 11) ^ (hash_val >> 25);
hash_val = hash_val * 69069 + 907133923;

Ok(hash_val as isize)
}

fn __reduce__(slf: PyRef<Self>) -> (Bound<'_, PyType>, (Vec<(Key, PyObject)>,)) {
(
HashTrieMapPy::type_object_bound(slf.py()),
Expand Down Expand Up @@ -775,6 +823,25 @@ impl HashTrieSetPy {
Ok(true)
}

fn __hash__(&self) -> PyResult<isize> {
// modified from https://github.com/python/cpython/blob/d69529d31ccd1510843cfac1ab53bb8cb027541f/Objects/setobject.c#L715

let mut hash_val = self
.inner
.iter()
.map(|k| k.hash as usize)
.fold(0, |acc: usize, x: usize| acc ^ hash_shuffle_bits(x));

// facto in the number of entries in the collection
hash_val ^= (self.inner.size() + 1) * 1927868237;

// dispense patterns in the hash value
hash_val ^= (hash_val >> 11) ^ (hash_val >> 25);
hash_val = hash_val * 69069 + 907133923;

Ok(hash_val as isize)
}

fn __lt__(slf: PyRef<'_, Self>, other: Bound<'_, PyAny>, py: Python) -> PyResult<bool> {
let abc = PyModule::import_bound(py, "collections.abc")?;
if !other.is_instance(&abc.getattr("Set")?)? || other.len()? <= slf.inner.size() {
Expand Down Expand Up @@ -1036,6 +1103,31 @@ impl ListPy {
}
}

fn __hash__(&self, py: Python) -> PyResult<u64> {
let mut hasher = DefaultHasher::new();

self.inner
.iter()
.enumerate()
.try_for_each(|(index, each)| {
each.bind(py)
.hash()
.map_err(|_| {
PyTypeError::new_err(format!(
"Unhashable type at {} element in List: {}",
index,
each.bind(py)
.repr()
.and_then(|r| r.extract())
.unwrap_or("<repr> error".to_string())
))
})
.map(|x| hasher.write_isize(x))
})?;

Ok(hasher.finish())
}

fn __iter__(slf: PyRef<'_, Self>) -> ListIterator {
ListIterator {
inner: slf.inner.clone(),
Expand Down Expand Up @@ -1178,12 +1270,27 @@ impl QueuePy {
}

fn __hash__(&self, py: Python<'_>) -> PyResult<u64> {
let hash = PyModule::import_bound(py, "builtins")?.getattr("hash")?;
let mut hasher = DefaultHasher::new();
for each in &self.inner {
let n: i64 = hash.call1((each.into_py(py),))?.extract()?;
hasher.write_i64(n);
}

self.inner
.iter()
.enumerate()
.try_for_each(|(index, each)| {
each.bind(py)
.hash()
.map_err(|_| {
PyTypeError::new_err(format!(
"Unhashable type at {} element in Queue: {}",
index,
each.bind(py)
.repr()
.and_then(|r| r.extract())
.unwrap_or("<repr> error".to_string())
))
})
.map(|x| hasher.write_isize(x))
})?;

Ok(hasher.finish())
}

Expand Down
31 changes: 21 additions & 10 deletions tests/test_hash_trie_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@

from rpds import HashTrieMap

HASH_MSG = "Not sure HashTrieMap implements Hash, it has mutable methods"


@pytest.mark.xfail(reason=HASH_MSG)
def test_instance_of_hashable():
assert isinstance(HashTrieMap(), abc.Hashable)

Expand Down Expand Up @@ -154,12 +151,22 @@ def test_overwrite_existing_element():
assert map2["a"] == 3


@pytest.mark.xfail(reason=HASH_MSG)
def test_hash():
x = HashTrieMap(a=1, b=2, c=3)
y = HashTrieMap(a=1, b=2, c=3)
def test_hashing():
o = object()

assert hash(x) == hash(y)
assert hash(HashTrieMap([(o, o), (1, o)])) == hash(
HashTrieMap([(o, o), (1, o)]),
)
assert hash(HashTrieMap([(o, o), (1, o)])) == hash(
HashTrieMap([(1, o), (o, o)]),
)
assert hash(HashTrieMap([(o, "foo")])) == hash(HashTrieMap([(o, "foo")]))
assert hash(HashTrieMap()) == hash(HashTrieMap([]))

assert hash(HashTrieMap({1: 2})) != hash(HashTrieMap({1: 3}))
assert hash(HashTrieMap({o: 1})) != hash(HashTrieMap({o: o}))
assert hash(HashTrieMap([])) != hash(HashTrieMap([(o, 1)]))
assert hash(HashTrieMap({1: 2, 3: 4})) != hash(HashTrieMap({1: 3, 2: 4}))


def test_same_hash_when_content_the_same_but_underlying_vector_size_differs():
Expand All @@ -183,13 +190,17 @@ def __hash__(self):
raise ValueError("I am not currently hashable.")


@pytest.mark.xfail(reason=HASH_MSG)
def test_map_does_not_hash_values_on_second_hash_invocation():
hashable = HashabilityControlled()
x = HashTrieMap(dict(el=hashable))
hash(x)

hashable.hashable = False
hash(x)
with pytest.raises(
TypeError,
match=r"Unhashable type in HashTrieMap of key 'el'",
):
hash(x)


def test_equal():
Expand Down
22 changes: 17 additions & 5 deletions tests/test_hash_trie_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

from rpds import HashTrieSet

HASH_MSG = "Not sure HashTrieSet implements Hash, it has mutable methods"


def test_key_is_tuple():
with pytest.raises(KeyError):
Expand All @@ -47,9 +45,23 @@ def test_key_is_not_tuple():
HashTrieSet().remove("asdf")


@pytest.mark.xfail(reason=HASH_MSG)
def test_supports_hash():
assert hash(HashTrieSet((1, 2))) == hash(HashTrieSet(1, 2))
def test_hashing():
o = object()

assert hash(HashTrieSet([o])) == hash(HashTrieSet([o]))
assert hash(HashTrieSet([o, o])) == hash(HashTrieSet([o, o]))
assert hash(HashTrieSet([])) == hash(HashTrieSet([]))
assert hash(HashTrieSet([1, 2])) == hash(HashTrieSet([1, 2]))
assert hash(HashTrieSet([1, 2])) == hash(HashTrieSet([2, 1]))
assert not (HashTrieSet([1, 2]) == HashTrieSet([1, 3]))
assert not (HashTrieSet([]) == HashTrieSet([o]))

assert hash(HashTrieSet([1, 2])) != hash(HashTrieSet([1, 3]))
assert hash(HashTrieSet([1, o])) != hash(HashTrieSet([1, 2]))
assert hash(HashTrieSet([1, 2])) != hash(HashTrieSet([2, 1, 3]))
assert not (HashTrieSet([o]) != HashTrieSet([o, o]))
assert not (HashTrieSet([o, o]) != HashTrieSet([o, o]))
assert not (HashTrieSet() != HashTrieSet([]))


def test_empty_truthiness():
Expand Down
20 changes: 16 additions & 4 deletions tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

from rpds import List

HASH_MSG = "Not sure List implements Hash, it has mutable methods"


def test_literalish_works():
assert List(1, 2, 3) == List([1, 2, 3])
Expand Down Expand Up @@ -100,10 +98,24 @@ def test_repr():
assert str(List([1, 2, 3])) in "List([1, 2, 3])"


@pytest.mark.xfail(reason=HASH_MSG)
def test_hashing():
assert hash(List([1, 2])) == hash(List([1, 2]))
o = object()

assert hash(List([o, o])) == hash(List([o, o]))
assert hash(List([o])) == hash(List([o]))
assert hash(List()) == hash(List([]))
assert not (hash(List([1, 2])) == hash(List([1, 3])))
assert not (hash(List([1, 2])) == hash(List([2, 1])))
assert not (hash(List([o])) == hash(List([o, o])))
assert not (hash(List([])) == hash(List([o])))

assert hash(List([1, 2])) != hash(List([1, 3]))
assert hash(List([1, 2])) != hash(List([2, 1]))
assert hash(List([o])) != hash(List([o, o]))
assert hash(List([])) != hash(List([o]))
assert not (hash(List([o, o])) != hash(List([o, o])))
assert not (hash(List([o])) != hash(List([o])))
assert not (hash(List([])) != hash(List([])))


def test_sequence():
Expand Down
2 changes: 0 additions & 2 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@

from rpds import Queue

HASH_MSG = "Not sure Queue implements Hash, it has mutable methods"


def test_literalish_works():
assert Queue(1, 2, 3) == Queue([1, 2, 3])
Expand Down