Skip to content

Commit

Permalink
Merge pull request #576 from Corvince/FastDC
Browse files (browse the repository at this point in the history)
Faster agent attribute collection
  • Loading branch information
dmasad authored Sep 30, 2018
2 parents f1201ab + 3f4f26b commit 08d4d0e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
54 changes: 37 additions & 17 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
* For collecting agent-level variables, agents must have a unique_id
"""
from collections import defaultdict
import itertools
from operator import attrgetter
import pandas as pd


Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
self.agent_reporters = {}

self.model_vars = {}
self.agent_vars = {}
self._agent_records = {}
self.tables = {}

if model_reporters is not None:
Expand Down Expand Up @@ -123,7 +124,6 @@ def _new_agent_reporter(self, name, reporter):
if type(reporter) is str:
reporter = self._make_attribute_collector(reporter)
self.agent_reporters[name] = reporter
self.agent_vars[name] = []

def _new_table(self, table_name, table_columns):
""" Add a new table that objects can write to.
Expand All @@ -136,18 +136,30 @@ def _new_table(self, table_name, table_columns):
new_table = {column: [] for column in table_columns}
self.tables[table_name] = new_table

def _record_agents(self, model):
    """Build one record per scheduled agent for the current step.

    Each record is a tuple ``(step, unique_id, *reports)``, where the
    reports follow the iteration order of ``self.agent_reporters``.
    The step is read from each agent's own ``model.schedule.steps``.

    Args:
        model: Object whose ``schedule.agents`` iterable is recorded.

    Returns:
        A lazy ``map`` iterator yielding one record tuple per agent.
    """
    # Snapshot the reporters so the lazily evaluated map below is not
    # affected by later changes to ``self.agent_reporters``.
    rep_funcs = list(self.agent_reporters.values())

    # Fast path only for collectors that carry the ``attribute_name``
    # marker. Using getattr/hasattr (instead of reading ``__name__``
    # directly) keeps reporters without ``__name__`` (functools.partial,
    # callable instances) and unrelated functions that merely share the
    # name 'attr_collector' from raising AttributeError; they simply take
    # the generic path.
    only_attribute_collectors = all(
        getattr(rep, '__name__', None) == 'attr_collector'
        and hasattr(rep, 'attribute_name')
        for rep in rep_funcs
    )

    if only_attribute_collectors:
        prefix = ['model.schedule.steps', 'unique_id']
        attributes = [rep.attribute_name for rep in rep_funcs]
        # A single attrgetter fetches every field in one call per agent.
        get_reports = attrgetter(*(prefix + attributes))
    else:
        def get_reports(agent):
            prefix = (agent.model.schedule.steps, agent.unique_id)
            reports = tuple(rep(agent) for rep in rep_funcs)
            return prefix + reports

    return map(get_reports, model.schedule.agents)

def collect(self, model):
""" Collect all the data for the given model object. """
if self.model_reporters:
for var, reporter in self.model_reporters.items():
self.model_vars[var].append(reporter(model))

if self.agent_reporters:
for var, reporter in self.agent_reporters.items():
agent_records = []
for agent in model.schedule.agents:
agent_records.append((agent.unique_id, reporter(agent)))
self.agent_vars[var].append(agent_records)
agent_records = self._record_agents(model)
self._agent_records[model.schedule.steps] = list(agent_records)

def add_table_row(self, table_name, row, ignore_missing=False):
""" Add a row dictionary to a specific table.
Expand Down Expand Up @@ -176,6 +188,14 @@ def _make_attribute_collector(attr):
Create a function which collects the value of a named attribute
'''

def attribute_name(attr):
"""Decorator to add attribute name as a function attribute."""
def wrapper(func):
func.attribute_name = attr
return func
return wrapper

@attribute_name(attr)
def attr_collector(obj):
return getattr(obj, attr)

Expand All @@ -197,15 +217,15 @@ def get_agent_vars_dataframe(self):
columns for tick and agent_id.
"""
data = defaultdict(dict)
for var, records in self.agent_vars.items():
for step, entries in enumerate(records):
for entry in entries:
agent_id = entry[0]
val = entry[1]
data[(step, agent_id)][var] = val
df = pd.DataFrame.from_dict(data, orient="index")
df.index.names = ["Step", "AgentID"]
all_records = itertools.chain.from_iterable(
self._agent_records.values())
rep_names = [rep_name for rep_name in self.agent_reporters]

df = pd.DataFrame.from_records(
data=all_records,
columns=["Step", "AgentID"] + rep_names,
)
df = df.set_index(["Step", "AgentID"])
return df

def get_table_dataframe(self, table_name):
Expand Down
14 changes: 6 additions & 8 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,16 @@ def test_model_vars(self):
for element in data_collector.model_vars["model_value"]:
assert element == 100

def test_agent_vars(self):
def test_agent_records(self):
'''
Test agent-level variable collection.
'''
data_collector = self.model.datacollector
assert len(data_collector.agent_vars["value"]) == 7
assert len(data_collector.agent_vars["value2"]) == 7
for var in ["value", "value2"]:
for step in data_collector.agent_vars[var]:
assert len(step) == 10
for record in step:
assert len(record) == 2
assert len(data_collector._agent_records) == 7
for step, records in data_collector._agent_records.items():
assert len(records) == 10
for values in records:
assert len(values) == 4

def test_table_rows(self):
'''
Expand Down

0 comments on commit 08d4d0e

Please sign in to comment.