Skip to content

Commit

Permalink
Update to CellCollection.select (#2307)
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored Sep 21, 2024
1 parent a686826 commit 56117dc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Cell:
def __init__(
self,
coordinate: Coordinate,
capacity: float | None = None,
capacity: int | None = None,
random: Random | None = None,
) -> None:
"""Initialise the cell.
Expand All @@ -65,7 +65,7 @@ def __init__(
self.agents: list[
Agent
] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
self.capacity = capacity
self.capacity: int = capacity
self.properties: dict[Coordinate, object] = {}
self.random = random

Expand Down
33 changes: 22 additions & 11 deletions mesa/experimental/cell_space/cell_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,36 @@ def select_random_agent(self) -> CellAgent:
"""
return self.random.choice(list(self.agents))

def select(self, filter_func: Callable[[T], bool] | None = None, n=0):
def select(
self,
filter_func: Callable[[T], bool] | None = None,
at_most: int | float = float("inf"),
):
"""Select cells based on filter function.
Args:
filter_func: filter function
n: number of cells to select
at_most: The maximum amount of cells to select. Defaults to infinity.
- If an integer, at most the first number of matching cells is selected.
- If a float between 0 and 1, at most that fraction of original number of cells
Returns:
CellCollection
"""
# FIXME: n is not considered
if filter_func is None and n == 0:
if filter_func is None and at_most == float("inf"):
return self

return CellCollection(
{
cell: agents
for cell, agents in self._cells.items()
if filter_func is None or filter_func(cell)
}
)
if at_most <= 1.0 and isinstance(at_most, float):
at_most = int(len(self) * at_most) # Note that it rounds down (floor)

def cell_generator(filter_func, at_most):
count = 0
for cell in self:
if count >= at_most:
break
if not filter_func or filter_func(cell):
yield cell
count += 1

return CellCollection(cell_generator(filter_func, at_most))
12 changes: 12 additions & 0 deletions tests/test_cell_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,15 @@ def test_cell_collection():

agents = collection[cells[0]]
assert agents == cells[0].agents

cell = collection.select(at_most=1)
assert len(cell) == 1

cells = collection.select(at_most=2)
assert len(cells) == 2

cells = collection.select(at_most=0.5)
assert len(cells) == 5

cells = collection.select()
assert len(cells) == len(collection)

0 comments on commit 56117dc

Please sign in to comment.