From 4ebbb0eec7dc067eac3690a3677aa2ff20d24a27 Mon Sep 17 00:00:00 2001 From: Corvince Date: Thu, 23 Aug 2018 10:26:25 +0200 Subject: [PATCH 1/5] Faster agent attribute collection Only significantly faster if all agent reporters are attributes --- mesa/datacollection.py | 34 ++++++++++++++++++++++------------ tests/test_datacollector.py | 12 +++++------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index dfa4fa92d86..108867ecdbb 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -37,6 +37,7 @@ """ from collections import defaultdict +from operator import attrgetter import pandas as pd @@ -91,8 +92,14 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): self._new_model_reporter(name, reporter) if agent_reporters is not None: - for name, reporter in agent_reporters.items(): - self._new_agent_reporter(name, reporter) + if all([type(rep) is str for rep in agent_reporters.values()]): + self.fast_collect = True + for name, reporter in agent_reporters.items(): + self.agent_reporters[name] = reporter + else: + self.fast_collect = False + for name, reporter in agent_reporters.items(): + self._new_agent_reporter(name, reporter) if tables is not None: for name, columns in tables.items(): @@ -123,7 +130,6 @@ def _new_agent_reporter(self, name, reporter): if type(reporter) is str: reporter = self._make_attribute_collector(reporter) self.agent_reporters[name] = reporter - self.agent_vars[name] = [] def _new_table(self, table_name, table_columns): """ Add a new table that objects can write to. @@ -143,11 +149,16 @@ def collect(self, model): self.model_vars[var].append(reporter(model)) if self.agent_reporters: - for var, reporter in self.agent_reporters.items(): - agent_records = [] + if self.fast_collect: + f = attrgetter(*self.agent_reporters.values()) + agent_records = { + agent.unique_id: f(agent) for agent + in model.schedule.agents} + else: + agent_records = {} for agent in model.schedule.agents: - agent_records.append((agent.unique_id, reporter(agent))) - self.agent_vars[var].append(agent_records) + agent_records[agent.unique_id] = tuple(rep(agent) for rep in self.agent_reporters.values()) + self.agent_vars[model.schedule.steps] = agent_records def add_table_row(self, table_name, row, ignore_missing=False): """ Add a row dictionary to a specific table. @@ -198,11 +209,10 @@ def get_agent_vars_dataframe(self): """ data = defaultdict(dict) - for var, records in self.agent_vars.items(): - for step, entries in enumerate(records): - for entry in entries: - agent_id = entry[0] - val = entry[1] + for step, records in self.agent_vars.items(): + for agent_id, vals in records.items(): + for i, val in enumerate(vals): + var = list(self.agent_reporters)[i] data[(step, agent_id)][var] = val df = pd.DataFrame.from_dict(data, orient="index") df.index.names = ["Step", "AgentID"] diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 6d572115155..4a2d44cea4c 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -88,13 +88,11 @@ def test_agent_vars(self): Test agent-level variable collection. ''' data_collector = self.model.datacollector - assert len(data_collector.agent_vars["value"]) == 7 - assert len(data_collector.agent_vars["value2"]) == 7 - for var in ["value", "value2"]: - for step in data_collector.agent_vars[var]: - assert len(step) == 10 - for record in step: - assert len(record) == 2 + assert len(data_collector.agent_vars) == 7 + for step, records in data_collector.agent_vars.items(): + assert len(records) == 10 + for values in records.values(): + assert len(values) == 2 def test_table_rows(self): ''' From aaa4c545fbaaeb7408c74148e96510421ea243fa Mon Sep 17 00:00:00 2001 From: Corvince Date: Fri, 24 Aug 2018 12:29:21 +0200 Subject: [PATCH 2/5] Lazy evaluation of reporter functions --- mesa/datacollection.py | 60 ++++++++++++++++++++++++------------- tests/test_datacollector.py | 11 +++---- 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 108867ecdbb..c9f262c8d9f 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -84,7 +84,7 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): self.agent_reporters = {} self.model_vars = {} - self.agent_vars = {} + self._agent_records = {} self.tables = {} if model_reporters is not None: @@ -92,12 +92,6 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): self._new_model_reporter(name, reporter) if agent_reporters is not None: - if all([type(rep) is str for rep in agent_reporters.values()]): - self.fast_collect = True - for name, reporter in agent_reporters.items(): - self.agent_reporters[name] = reporter - else: - self.fast_collect = False for name, reporter in agent_reporters.items(): self._new_agent_reporter(name, reporter) @@ -142,6 +136,23 @@ def _new_table(self, table_name, table_columns): new_table = {column: [] for column in table_columns} self.tables[table_name] = new_table + @staticmethod + def _repgetter(reporters): + """ Get reports from a list of agents. + + Args: + reporters: List of reporter functions. + + """ + if all([rep.__name__ == 'attr_collector' for rep in reporters]): + # Fast path if all reporters are attribute collecters + report = attrgetter(*[rep.attribute_name for rep in reporters]) + else: + def report(agent): + return (agent.unique_id, ) + tuple( + get_report(rep, agent) for rep in reporters) + return report + def collect(self, model): """ Collect all the data for the given model object. """ if self.model_reporters: @@ -149,16 +160,9 @@ def collect(self, model): self.model_vars[var].append(reporter(model)) if self.agent_reporters: - if self.fast_collect: - f = attrgetter(*self.agent_reporters.values()) - agent_records = { - agent.unique_id: f(agent) for agent - in model.schedule.agents} - else: - agent_records = {} - for agent in model.schedule.agents: - agent_records[agent.unique_id] = tuple(rep(agent) for rep in self.agent_reporters.values()) - self.agent_vars[model.schedule.steps] = agent_records + f = self._repgetter(self.agent_reporters.values()) + agent_records = map(f, model.schedule.agents) + self._agent_records[model.schedule.steps] = agent_records def add_table_row(self, table_name, row, ignore_missing=False): """ Add a row dictionary to a specific table. @@ -187,6 +191,14 @@ def _make_attribute_collector(attr): Create a function which collects the value of a named attribute ''' + def attribute_name(attr): + """Decorator to add attribute name as a function attribute.""" + def wrapper(func): + func.attribute_name = attr + return func + return wrapper + + @attribute_name(attr) def attr_collector(obj): return getattr(obj, attr) @@ -209,11 +221,12 @@ def get_agent_vars_dataframe(self): """ data = defaultdict(dict) - for step, records in self.agent_vars.items(): - for agent_id, vals in records.items(): - for i, val in enumerate(vals): + for step, records in self._agent_records.items(): + for record in records: + agent_id = record[0] + for i, value in enumerate(record[1:]): var = list(self.agent_reporters)[i] - data[(step, agent_id)][var] = val + data[(step, agent_id)][var] = value df = pd.DataFrame.from_dict(data, orient="index") df.index.names = ["Step", "AgentID"] return df @@ -228,3 +241,8 @@ def get_table_dataframe(self, table_name): if table_name not in self.tables: raise Exception("No such table.") return pd.DataFrame(self.tables[table_name]) + + +def get_report(rep, agent): + """ Get a report from an agent.""" + return rep(agent) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 4a2d44cea4c..e9d8542776f 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -83,16 +83,17 @@ def test_model_vars(self): for element in data_collector.model_vars["model_value"]: assert element == 100 - def test_agent_vars(self): + def test_agent_records(self): ''' Test agent-level variable collection. ''' data_collector = self.model.datacollector - assert len(data_collector.agent_vars) == 7 - for step, records in data_collector.agent_vars.items(): + assert len(data_collector._agent_records) == 7 + for step, records in data_collector._agent_records.items(): + records = list(records) assert len(records) == 10 - for values in records.values(): - assert len(values) == 2 + for values in records: + assert len(values) == 3 def test_table_rows(self): ''' From 81181076df89595d3a3fce55529740948d869e55 Mon Sep 17 00:00:00 2001 From: Corvince Date: Fri, 24 Aug 2018 12:55:31 +0200 Subject: [PATCH 3/5] small bugfix --- mesa/datacollection.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index c9f262c8d9f..2d4f850e5e5 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -92,8 +92,8 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None): self._new_model_reporter(name, reporter) if agent_reporters is not None: - for name, reporter in agent_reporters.items(): - self._new_agent_reporter(name, reporter) + for name, reporter in agent_reporters.items(): + self._new_agent_reporter(name, reporter) if tables is not None: for name, columns in tables.items(): @@ -146,7 +146,9 @@ def _repgetter(reporters): """ if all([rep.__name__ == 'attr_collector' for rep in reporters]): # Fast path if all reporters are attribute collecters - report = attrgetter(*[rep.attribute_name for rep in reporters]) + def report(agent): + f = attrgetter(*[rep.attribute_name for rep in reporters]) + return (agent.unique_id, ) + f(agent) else: def report(agent): return (agent.unique_id, ) + tuple( From 920c9cabb15d7c171601fc2281b0f22d45a503fe Mon Sep 17 00:00:00 2001 From: Corvince Date: Tue, 11 Sep 2018 10:02:58 +0200 Subject: [PATCH 4/5] Revert to a sane approach --- mesa/datacollection.py | 57 ++++++++++++++----------------------- tests/test_datacollector.py | 3 +- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 2d4f850e5e5..bc1391c3456 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -36,7 +36,7 @@ * For collecting agent-level variables, agents must have a unique_id """ -from collections import defaultdict +import itertools from operator import attrgetter import pandas as pd @@ -136,25 +136,6 @@ def _new_table(self, table_name, table_columns): new_table = {column: [] for column in table_columns} self.tables[table_name] = new_table - @staticmethod - def _repgetter(reporters): - """ Get reports from a list of agents. - - Args: - reporters: List of reporter functions. - - """ - if all([rep.__name__ == 'attr_collector' for rep in reporters]): - # Fast path if all reporters are attribute collecters - def report(agent): - f = attrgetter(*[rep.attribute_name for rep in reporters]) - return (agent.unique_id, ) + f(agent) - else: - def report(agent): - return (agent.unique_id, ) + tuple( - get_report(rep, agent) for rep in reporters) - return report - def collect(self, model): """ Collect all the data for the given model object. """ if self.model_reporters: @@ -162,8 +143,17 @@ def collect(self, model): self.model_vars[var].append(reporter(model)) if self.agent_reporters: - f = self._repgetter(self.agent_reporters.values()) - agent_records = map(f, model.schedule.agents) + rep_funcs = self.agent_reporters.values() + if all([rep.__name__ == 'attr_collector' for rep in rep_funcs]): + attributes = [func.attribute_name for func in rep_funcs] + get_reports = attrgetter( + *['model.schedule.steps', 'unique_id'] + attributes) + else: + def get_reports(agent): + return ( + agent.model.schedule.steps, agent.unique_id) + tuple( + rep(agent) for rep in rep_funcs) + agent_records = [*map(get_reports, model.schedule.agents)] self._agent_records[model.schedule.steps] = agent_records def add_table_row(self, table_name, row, ignore_missing=False): @@ -222,15 +212,15 @@ def get_agent_vars_dataframe(self): columns for tick and agent_id. """ - data = defaultdict(dict) - for step, records in self._agent_records.items(): - for record in records: - agent_id = record[0] - for i, value in enumerate(record[1:]): - var = list(self.agent_reporters)[i] - data[(step, agent_id)][var] = value - df = pd.DataFrame.from_dict(data, orient="index") - df.index.names = ["Step", "AgentID"] + all_records = itertools.chain.from_iterable( + self._agent_records.values()) + rep_names = [rep_name for rep_name in self.agent_reporters] + + df = pd.DataFrame.from_records( + data=all_records, + columns=["Step", "AgentID"] + rep_names, + ) + df = df.set_index(["Step", "AgentID"]) return df def get_table_dataframe(self, table_name): @@ -243,8 +233,3 @@ def get_table_dataframe(self, table_name): if table_name not in self.tables: raise Exception("No such table.") return pd.DataFrame(self.tables[table_name]) - - -def get_report(rep, agent): - """ Get a report from an agent.""" - return rep(agent) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index e9d8542776f..1a6a748a808 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -90,10 +90,9 @@ def test_agent_records(self): data_collector = self.model.datacollector assert len(data_collector._agent_records) == 7 for step, records in data_collector._agent_records.items(): - records = list(records) assert len(records) == 10 for values in records: - assert len(values) == 3 + assert len(values) == 4 def test_table_rows(self): ''' From 3f4f26b7cae68aef88bcc86f642965c778cbede3 Mon Sep 17 00:00:00 2001 From: Corvince Date: Tue, 11 Sep 2018 11:53:24 +0200 Subject: [PATCH 5/5] add _record_agents function --- mesa/datacollection.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index bc1391c3456..18360891efa 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -136,6 +136,21 @@ def _new_table(self, table_name, table_columns): new_table = {column: [] for column in table_columns} self.tables[table_name] = new_table + def _record_agents(self, model): + """ Record agents data in a mapping of functions and agents. """ + rep_funcs = self.agent_reporters.values() + if all([rep.__name__ == 'attr_collector' for rep in rep_funcs]): + prefix = ['model.schedule.steps', 'unique_id'] + attributes = [func.attribute_name for func in rep_funcs] + get_reports = attrgetter(*prefix + attributes) + else: + def get_reports(agent): + prefix = (agent.model.schedule.steps, agent.unique_id) + reports = tuple(rep(agent) for rep in rep_funcs) + return prefix + reports + agent_records = map(get_reports, model.schedule.agents) + return agent_records + def collect(self, model): """ Collect all the data for the given model object. """ if self.model_reporters: @@ -143,18 +158,8 @@ def collect(self, model): self.model_vars[var].append(reporter(model)) if self.agent_reporters: - rep_funcs = self.agent_reporters.values() - if all([rep.__name__ == 'attr_collector' for rep in rep_funcs]): - attributes = [func.attribute_name for func in rep_funcs] - get_reports = attrgetter( - *['model.schedule.steps', 'unique_id'] + attributes) - else: - def get_reports(agent): - return ( - agent.model.schedule.steps, agent.unique_id) + tuple( - rep(agent) for rep in rep_funcs) - agent_records = [*map(get_reports, model.schedule.agents)] - self._agent_records[model.schedule.steps] = agent_records + agent_records = self._record_agents(model) + self._agent_records[model.schedule.steps] = list(agent_records) def add_table_row(self, table_name, row, ignore_missing=False): """ Add a row dictionary to a specific table.