diff --git a/mesa/datacollection.py b/mesa/datacollection.py index dfa4fa92d86..18360891efa 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -36,7 +36,8 @@ * For collecting agent-level variables, agents must have a unique_id """ -from collections import defaultdict +import itertools +from operator import attrgetter import pandas as pd @@ -83,7 +84,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): self.agent_reporters = {} self.model_vars = {} - self.agent_vars = {} + self._agent_records = {} self.tables = {} if model_reporters is not None: @@ -123,7 +124,6 @@ def _new_agent_reporter(self, name, reporter): if type(reporter) is str: reporter = self._make_attribute_collector(reporter) self.agent_reporters[name] = reporter - self.agent_vars[name] = [] def _new_table(self, table_name, table_columns): """ Add a new table that objects can write to. @@ -136,6 +136,21 @@ def _new_table(self, table_name, table_columns): new_table = {column: [] for column in table_columns} self.tables[table_name] = new_table + def _record_agents(self, model): + """ Record agents data in a mapping of functions and agents. """ + rep_funcs = self.agent_reporters.values() + if all([rep.__name__ == 'attr_collector' for rep in rep_funcs]): + prefix = ['model.schedule.steps', 'unique_id'] + attributes = [func.attribute_name for func in rep_funcs] + get_reports = attrgetter(*prefix + attributes) + else: + def get_reports(agent): + prefix = (agent.model.schedule.steps, agent.unique_id) + reports = tuple(rep(agent) for rep in rep_funcs) + return prefix + reports + agent_records = map(get_reports, model.schedule.agents) + return agent_records + def collect(self, model): """ Collect all the data for the given model object. """ if self.model_reporters: @@ -143,11 +158,8 @@ def collect(self, model): self.model_vars[var].append(reporter(model)) if self.agent_reporters: - for var, reporter in self.agent_reporters.items(): - agent_records = [] - for agent in model.schedule.agents: - agent_records.append((agent.unique_id, reporter(agent))) - self.agent_vars[var].append(agent_records) + agent_records = self._record_agents(model) + self._agent_records[model.schedule.steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """ Add a row dictionary to a specific table. @@ -176,6 +188,14 @@ def _make_attribute_collector(attr): Create a function which collects the value of a named attribute ''' + def attribute_name(attr): + """Decorator to add attribute name as a function attribute.""" + def wrapper(func): + func.attribute_name = attr + return func + return wrapper + + @attribute_name(attr) def attr_collector(obj): return getattr(obj, attr) @@ -197,15 +217,15 @@ def get_agent_vars_dataframe(self): columns for tick and agent_id. """ - data = defaultdict(dict) - for var, records in self.agent_vars.items(): - for step, entries in enumerate(records): - for entry in entries: - agent_id = entry[0] - val = entry[1] - data[(step, agent_id)][var] = val - df = pd.DataFrame.from_dict(data, orient="index") - df.index.names = ["Step", "AgentID"] + all_records = itertools.chain.from_iterable( + self._agent_records.values()) + rep_names = [rep_name for rep_name in self.agent_reporters] + + df = pd.DataFrame.from_records( + data=all_records, + columns=["Step", "AgentID"] + rep_names, + ) + df = df.set_index(["Step", "AgentID"]) return df def get_table_dataframe(self, table_name): diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 6d572115155..1a6a748a808 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -83,18 +83,16 @@ def test_model_vars(self): for element in data_collector.model_vars["model_value"]: assert element == 100 - def test_agent_vars(self): + def test_agent_records(self): ''' Test agent-level variable collection. ''' data_collector = self.model.datacollector - assert len(data_collector.agent_vars["value"]) == 7 - assert len(data_collector.agent_vars["value2"]) == 7 - for var in ["value", "value2"]: - for step in data_collector.agent_vars[var]: - assert len(step) == 10 - for record in step: - assert len(record) == 2 + assert len(data_collector._agent_records) == 7 + for step, records in data_collector._agent_records.items(): + assert len(records) == 10 + for values in records: + assert len(values) == 4 def test_table_rows(self): '''