Skip to content

Commit

Permalink
Refactor LookupDict to avoid recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
vemel committed Oct 28, 2024
1 parent 1955c63 commit c87b1b5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 40 deletions.
79 changes: 40 additions & 39 deletions mypy_boto3_builder/utils/lookup_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,65 +16,66 @@ class LookupDict(Generic[_V]):
"""
Lookup dictionary to get values by multiple keys.
Supports "*" as a wildcard key in a lookup dict.
Arguments:
hash_map -- Initial hash map.
required_keys -- Number of required keys in th end of key tuple.
hash_map -- Lookup hash map.
"""

_ALL: ClassVar = ALL

def __init__(
self,
hash_map: Mapping[str, _T],
items: Mapping[str, _T],
) -> None:
self._hash_map = hash_map
self._lookup_hash_map: dict[tuple[str, ...], _V] = {}
self._items = items
self._lookup_items: dict[tuple[str, ...], _V] = {}
self._keys_len = 0
self._products: tuple[tuple[bool, ...], ...] = ()

@property
def _lookup(self) -> dict[tuple[str, ...], _V]:
if not self._lookup_hash_map:
self._lookup_hash_map = self._generate_lookup(
{str(k): v for k, v in self._hash_map.items()}, ()
)
self._keys_len = len(next(iter(self._lookup_hash_map)))
self._products = self._generate_products()

return self._lookup_hash_map

def _generate_products(self) -> tuple[tuple[bool, ...], ...]:
static_keys = [True for _i in range(self._keys_len)]
for keys in self._lookup_hash_map:
for i, key in enumerate(keys):
if key == self._ALL:
static_keys[i] = False

frozen_static_keys = tuple(static_keys)
if all(frozen_static_keys):
return (frozen_static_keys,)

shifting_keys_count = len([key for key in static_keys if not key])
products = tuple(itertools.product((True, False), repeat=shifting_keys_count))
return tuple(self._join_product(frozen_static_keys, product) for product in products)
if not self._lookup_items:
self._lookup_items = self._generate_lookup({str(k): v for k, v in self._items.items()})
self._keys_len = len(next(iter(self._lookup_items)))
self._products = tuple(self._generate_products())

return self._lookup_items

def _generate_products(self) -> Iterator[tuple[bool, ...]]:
static_keys = tuple(
not any(keys[i] == self._ALL for keys in self._lookup_items)
for i in range(self._keys_len)
)
if all(static_keys):
yield static_keys
return

shifting_keys_count = static_keys.count(False)
products = itertools.product((True, False), repeat=shifting_keys_count)
for product in products:
yield self._join_product(static_keys, product)

@staticmethod
def _join_product(static_keys: tuple[bool, ...], product: tuple[bool, ...]) -> tuple[bool, ...]:
product_iter = iter(product)
return tuple(True if key else next(product_iter) for key in static_keys)

def _generate_lookup(
self, hash_map: Mapping[str, _T], keys: tuple[str, ...]
) -> dict[tuple[str, ...], _V]:
def _generate_lookup(self, hash_map: Mapping[str, _T]) -> dict[tuple[str, ...], _V]:
result: dict[tuple[str, ...], _V] = {}
for key, value in hash_map.items():
new_keys = (*keys, key)
if isinstance(value, Mapping):
value = cast(Mapping[str, _T], value) # type: ignore
result.update(self._generate_lookup(value, new_keys))
else:
value_v = cast(_V, value) # type: ignore
result[new_keys] = value_v
stack: list[tuple[Mapping[str, _T], tuple[str, ...]]] = [(hash_map, ())]

while stack:
current_map, current_keys = stack.pop()
for key, value in current_map.items():
new_keys = (*current_keys, key)
if isinstance(value, Mapping):
value_t = cast(Mapping[str, _T], value)
stack.append((value_t, new_keys))
else:
value_v = cast(_V, value)
result[new_keys] = value_v

return result

def _iterate_lookup_keys(self, keys: tuple[str, ...]) -> Iterator[tuple[str, ...]]:
Expand Down
11 changes: 10 additions & 1 deletion tests/utils/test_lookup_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class TestLookupDict:
def test_init(self) -> None:
def test_get(self) -> None:
lookup_dict: LookupDict[int] = LookupDict(
{"test": {"one": {"child1": 1}, ALL: {"child2": 10}}, ALL: {"two": {"child3": 12}}}
)
Expand All @@ -14,3 +14,12 @@ def test_init(self) -> None:
assert lookup_dict.get("test", "two", "child2") == 10
assert lookup_dict.get("test2", "two", "child2") is None
assert lookup_dict.get("test2", "two", "child3") == 12

def test_static(self) -> None:
lookup_dict: LookupDict[int] = LookupDict(
{"test": {"one": {"child1": 1}, "two": {"child2": 10}}}
)
assert lookup_dict.get("test", "one", "child1") == 1
assert lookup_dict.get("test", "one", "child2") is None
assert lookup_dict.get("test", "two", "child1") is None
assert lookup_dict.get("test", "two", "child2") == 10

0 comments on commit c87b1b5

Please sign in to comment.