-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for pattern matching with the unification library (#78)
- Loading branch information
Showing
4 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import functools | ||
|
||
import unification.match | ||
from unification import unify | ||
from unification.variable import isvar | ||
|
||
import funsor.ops as ops | ||
from funsor.interpreter import dispatched_interpretation, interpretation | ||
from funsor.terms import Binary, Funsor, Variable, lazy | ||
|
||
|
||
@dispatched_interpretation | ||
def unify_interpreter(cls, *args): | ||
result = unify_interpreter.dispatch(cls, *args) | ||
if result is None: | ||
result = lazy(cls, *args) | ||
return result | ||
|
||
|
||
@unify_interpreter.register(Binary, ops.Op, Funsor, Funsor) | ||
def unify_eq(op, lhs, rhs): | ||
if op is ops.eq: | ||
return lhs is rhs # equality via cons-hashing | ||
return None | ||
|
||
|
||
class EqDispatcher(unification.match.Dispatcher): | ||
|
||
resolve = interpretation(unify_interpreter)(unification.match.Dispatcher.resolve) | ||
|
||
|
||
class EqVarDispatcher(EqDispatcher): | ||
|
||
def __call__(self, *args, **kwargs): | ||
func, s = self.resolve(args) | ||
d = dict((k.name if isinstance(k, Variable) else k.token, v) for k, v in s.items()) | ||
return func(**d) | ||
|
||
|
||
@isvar.register(Variable) | ||
def _isvar_funsor_variable(v): | ||
return True | ||
|
||
|
||
@unify.register(Funsor, Funsor, dict) | ||
@interpretation(unify_interpreter) | ||
def unify_funsor(pattern, expr, subs): | ||
if type(pattern) is not type(expr): | ||
return False | ||
return unify(pattern._ast_values, expr._ast_values, subs) | ||
|
||
|
||
@unify.register(Variable, (Variable, Funsor), dict) | ||
@unify.register((Variable, Funsor), Variable, dict) | ||
def unify_patternvar(pattern, expr, subs): | ||
subs.update({pattern: expr} if isinstance(pattern, Variable) else {expr: pattern}) | ||
return subs | ||
|
||
|
||
match_vars = functools.partial(unification.match.match, Dispatcher=EqVarDispatcher) | ||
match = functools.partial(unification.match.match, Dispatcher=EqDispatcher) | ||
|
||
|
||
__all__ = [ | ||
"match", | ||
"match_vars", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
from funsor.domains import reals | ||
from funsor.interpreter import interpretation, reinterpret | ||
from funsor.terms import Number, Variable, lazy | ||
|
||
from funsor.pattern import match, match_vars, unify_interpreter | ||
from unification import unify | ||
|
||
|
||
def test_unify_binary(): | ||
with interpretation(lazy): | ||
pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals()) | ||
expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.)) | ||
|
||
subs = unify(pattern, expr) | ||
print(subs, pattern(**{k.name: v for k, v in subs.items()})) | ||
assert subs is not False | ||
|
||
with interpretation(unify_interpreter): | ||
assert unify((pattern,), (expr,)) is not False | ||
|
||
|
||
def test_match_binary(): | ||
with interpretation(lazy): | ||
pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals()) | ||
expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.)) | ||
|
||
@match_vars(pattern) | ||
def expand_2_vars(a, b): | ||
return a + b + b | ||
|
||
@match(pattern) | ||
def expand_2_walk(x): | ||
return x.lhs + x.rhs.rhs + x.rhs.rhs | ||
|
||
eager_val = reinterpret(expr) | ||
lazy_val = expand_2_vars(expr) | ||
assert eager_val == reinterpret(lazy_val) | ||
|
||
lazy_val_2 = expand_2_walk(expr) | ||
assert eager_val == reinterpret(lazy_val_2) |