Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to more accurate way of calculating ancestors #3213

Merged
merged 1 commit into from
Oct 11, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from theano.gof.graph import inputs
import itertools

from theano.gof.graph import ancestors

from .util import get_default_varnames
import pymc3 as pm


def powerset(iterable):
"""All *nonempty* subsets of an iterable.

From itertools docs.

powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
"""
s = list(iterable)
return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(1, len(s)+1))


class ModelGraph(object):
def __init__(self, model):
self.model = model
Expand All @@ -21,30 +34,29 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _inputs(self, var, func, blockers=None):
"""Get inputs to a function that are also named PyMC3 variables"""
return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var])
def _ancestors(self, var, func, blockers=None):
"""Get ancestors of a function that are also named PyMC3 variables"""
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])

def _get_inputs(self, var, func):
"""Get all inputs to a function, doing some accounting for deterministics
def _get_ancestors(self, var, func):
"""Get all ancestors of a function, doing some accounting for deterministics

Specifically, if a deterministic is an input, theano.gof.graph.inputs will
Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
return only the inputs *to the deterministic*. However, if we pass in the
deterministic as a blocker, it will skip those nodes.
"""
deterministics = self.get_deterministics(var)
upstream = self._inputs(var, func)
parents = self._inputs(var, func, blockers=deterministics)
if parents != upstream:
det_map = {}
for d in deterministics:
d_set = {j for j in inputs([func], blockers=[d])}
if upstream - d_set:
det_map[d] = d_set
for d, d_set in det_map.items():
if all(d_set.issubset(other) for other in det_map.values()):
parents.add(d)
return parents
upstream = self._ancestors(var, func)

# Usual case
if upstream == self._ancestors(var, func, blockers=upstream):
return upstream
else: # deterministic accounting
for d in powerset(upstream):
blocked = self._ancestors(var, func, blockers=d)
if set(d) == blocked:
return d
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')

def _filter_parents(self, var, parents):
"""Get direct parents of a var, as strings"""
Expand All @@ -70,7 +82,7 @@ def get_parents(self, var):
else:
func = var

parents = self._get_inputs(var, func)
parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)

def make_compute_graph(self):
Expand Down