Skip to content

Commit

Permalink
Compare AttributeDict equality with their hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
pacrob committed Sep 21, 2023
1 parent ffec086 commit 755521a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
1 change: 1 addition & 0 deletions newsfragments/3014.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use hashes to compare equality of two ``AttributeDict``s
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import random
import re

from web3.datastructures import (
Expand All @@ -7,6 +8,82 @@
)


def generate_random_value(depth=0, max_depth=3, key_type=None):
if depth > max_depth:
return None

choice = random.choice(["str", "int", "bool", "list", "tuple", "dict"])

if choice == "str":
return "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5))
elif choice == "int":
return random.randint(0, 100)
elif choice == "bool":
return random.choice([True, False])
elif choice == "list":
return [
generate_random_value(depth + 1, key_type=key_type)
for _ in range(random.randint(1, 5))
]
elif choice == "tuple":
return tuple(
generate_random_value(depth + 1, key_type=key_type)
for _ in range(random.randint(1, 5))
)
elif choice == "dict":
return generate_random_dict(depth + 1, key_type=key_type)


def generate_random_dict(depth=0, max_depth=3, max_keys=5, key_type=None):
if not key_type:
key_type = random.choice(["str", "int", "float"])

if depth > max_depth:
return {}

result = {}
for _ in range(random.randint(1, max_keys)):
if key_type == "str":
key = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5))
elif key_type == "int":
key = random.randint(0, 100)
elif key_type == "float":
key = round(random.uniform(0, 100), 2)

value = generate_random_value(depth, max_depth, key_type)
result[key] = value

return AttributeDict.recursive(result)


def test_attribute_dict_raises_when_mutated():
test_attr_dict = AttributeDict(
{"address": "0x4CB06C43fcdABeA22541fcF1F856A6a296448B6c"}
)
with pytest.raises(TypeError):
# should raise when trying to edit an attribute with dict syntax
test_attr_dict["address"] = "0x6C8f2A135f6ed072DE4503Bd7C4999a1a17F824B"

with pytest.raises(TypeError):
# should raise when trying to edit an attribute with dot syntax
test_attr_dict.address = "0x6C8f2A135f6ed072DE4503Bd7C4999a1a17F824B"

with pytest.raises(TypeError):
# should raise when trying to add an attribute with dict syntax
test_attr_dict["cats"] = "0x6C8f2A135f6ed072DE4503Bd7C4999a1a17F824B"

with pytest.raises(TypeError):
# should raise when trying to add an attribute with dot syntax
test_attr_dict.cats = "0x6C8f2A135f6ed072DE4503Bd7C4999a1a17F824B"

with pytest.raises(TypeError):
# should raise when trying to delete an attribute with dict syntax
del test_attr_dict["address"]

with pytest.raises(TypeError):
# should raise when trying to delete an attribute with dot syntax
del test_attr_dict.address

@pytest.mark.parametrize(
"input,expected",
(
Expand Down Expand Up @@ -50,9 +127,18 @@
),
),
)
def test_tupleization_and_hashing(input, expected):
def test_tupleization_and_hashing_passing_defined(input, expected):
assert tupleize_lists_nested(input) == expected
assert hash(AttributeDict(input)) == hash(expected)
assert AttributeDict(input) == expected


def test_tupleization_and_hashing_passing_random():
for _ in range(1000):
random_dict = generate_random_dict()
tupleized_random_dict = tupleize_lists_nested(random_dict)
assert hash(random_dict) == hash(tupleized_random_dict)
assert random_dict == tupleized_random_dict


@pytest.mark.parametrize(
Expand Down
11 changes: 7 additions & 4 deletions web3/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -117,7 +118,9 @@ def __hash__(self) -> int:
return hash(tuple(sorted(tupleize_lists_nested(self).items())))

def __eq__(self, other: Any) -> bool:
if isinstance(other, Mapping):
if isinstance(other, AttributeDict):
return hash(self) == hash(other)
elif isinstance(other, Mapping):
return self.__dict__ == dict(other)
else:
return False
Expand All @@ -130,12 +133,12 @@ def tupleize_lists_nested(d: Mapping[TKey, TValue]) -> AttributeDict[TKey, TValu
Other unhashable types found will raise a TypeError
"""

def _to_tuple(lst: List[Any]) -> Any:
return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst)
def _to_tuple(value: Union[List[Any], Tuple[Any, ...]]) -> Any:
return tuple(_to_tuple(i) if isinstance(i, (list, tuple)) else i for i in value)

ret = dict()
for k, v in d.items():
if isinstance(v, List):
if isinstance(v, (list, tuple)):
ret[k] = _to_tuple(v)
elif isinstance(v, Mapping):
ret[k] = tupleize_lists_nested(v)
Expand Down

0 comments on commit 755521a

Please sign in to comment.