Skip to content

Commit

Permalink
Support for pattern matching with the unification library (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored and fritzo committed Mar 27, 2019
1 parent 364745e commit a54defa
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 0 deletions.
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
joint,
montecarlo,
ops,
pattern,
sum_product,
terms,
torch
Expand Down Expand Up @@ -48,6 +49,7 @@
'montecarlo',
'of_shape',
'ops',
'pattern',
'reals',
'reinterpret',
'sum_product',
Expand Down
69 changes: 69 additions & 0 deletions funsor/pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import absolute_import, division, print_function

import functools

import unification.match
from unification import unify
from unification.variable import isvar

import funsor.ops as ops
from funsor.interpreter import dispatched_interpretation, interpretation
from funsor.terms import Binary, Funsor, Variable, lazy


@dispatched_interpretation
def unify_interpreter(cls, *args):
result = unify_interpreter.dispatch(cls, *args)
if result is None:
result = lazy(cls, *args)
return result


@unify_interpreter.register(Binary, ops.Op, Funsor, Funsor)
def unify_eq(op, lhs, rhs):
if op is ops.eq:
return lhs is rhs # equality via cons-hashing
return None


class EqDispatcher(unification.match.Dispatcher):

resolve = interpretation(unify_interpreter)(unification.match.Dispatcher.resolve)


class EqVarDispatcher(EqDispatcher):

def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = dict((k.name if isinstance(k, Variable) else k.token, v) for k, v in s.items())
return func(**d)


@isvar.register(Variable)
def _isvar_funsor_variable(v):
return True


@unify.register(Funsor, Funsor, dict)
@interpretation(unify_interpreter)
def unify_funsor(pattern, expr, subs):
if type(pattern) is not type(expr):
return False
return unify(pattern._ast_values, expr._ast_values, subs)


@unify.register(Variable, (Variable, Funsor), dict)
@unify.register((Variable, Funsor), Variable, dict)
def unify_patternvar(pattern, expr, subs):
subs.update({pattern: expr} if isinstance(pattern, Variable) else {expr: pattern})
return subs


match_vars = functools.partial(unification.match.match, Dispatcher=EqVarDispatcher)
match = functools.partial(unification.match.match, Dispatcher=EqDispatcher)


__all__ = [
"match",
"match_vars",
]
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
'pyro-ppl>=0.3',
'six>=1.10.0',
'torch>=1.0.0',
'unification',
],
extras_require={
'test': ['flake8', 'pytest>=4.1'],
Expand Down
42 changes: 42 additions & 0 deletions test/test_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import absolute_import, division, print_function

from funsor.domains import reals
from funsor.interpreter import interpretation, reinterpret
from funsor.terms import Number, Variable, lazy

from funsor.pattern import match, match_vars, unify_interpreter
from unification import unify


def test_unify_binary():
with interpretation(lazy):
pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

subs = unify(pattern, expr)
print(subs, pattern(**{k.name: v for k, v in subs.items()}))
assert subs is not False

with interpretation(unify_interpreter):
assert unify((pattern,), (expr,)) is not False


def test_match_binary():
with interpretation(lazy):
pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

@match_vars(pattern)
def expand_2_vars(a, b):
return a + b + b

@match(pattern)
def expand_2_walk(x):
return x.lhs + x.rhs.rhs + x.rhs.rhs

eager_val = reinterpret(expr)
lazy_val = expand_2_vars(expr)
assert eager_val == reinterpret(lazy_val)

lazy_val_2 = expand_2_walk(expr)
assert eager_val == reinterpret(lazy_val_2)

0 comments on commit a54defa

Please sign in to comment.