diff --git a/Makefile b/Makefile index 7686b3dda..2d6bab619 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ format: FORCE test: lint FORCE pytest -v test + FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py python examples/discrete_hmm.py -n 2 python examples/discrete_hmm.py -n 2 -t 50 --lazy python examples/kalman_filter.py --xfail-if-not-implemented diff --git a/funsor/contract.py b/funsor/contract.py index e7d374d0e..ef002e875 100644 --- a/funsor/contract.py +++ b/funsor/contract.py @@ -3,6 +3,7 @@ import functools from collections import OrderedDict +import funsor.interpreter as interpreter import funsor.ops as ops from funsor.optimizer import Finitary, optimize from funsor.sum_product import _partition @@ -51,6 +52,7 @@ def contractor(fn): """ Decorator for contract implementations to simplify inputs. """ + fn = interpreter.debug_logged(fn) return functools.partial(_simplify_contract, fn) diff --git a/funsor/integrate.py b/funsor/integrate.py index 7799b1aac..fc691f708 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -3,6 +3,7 @@ import functools from collections import OrderedDict +import funsor.interpreter as interpreter import funsor.ops as ops from funsor.contract import Contract from funsor.terms import Funsor, Reduce, eager @@ -58,6 +59,7 @@ def integrator(fn): """ Decorator for integration implementations. """ + fn = interpreter.debug_logged(fn) return functools.partial(_simplify_integrate, fn) diff --git a/funsor/interpreter.py b/funsor/interpreter.py index bab0205ba..2cdba6834 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import functools +import inspect import os import re import types @@ -13,6 +15,7 @@ from funsor.registry import KeyedRegistry from funsor.six import singledispatch +_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) _DEBUG = int(os.environ.get("FUNSOR_DEBUG", 0)) _STACK_SIZE = 0 @@ -117,12 +120,38 @@ def _reinterpret_ordereddict(x): return OrderedDict((key, reinterpret(value)) for key, value in x.items()) +if _DEBUG: + class DebugLogged(object): + def __init__(self, fn): + self.fn = fn + while isinstance(fn, functools.partial): + fn = fn.func + path = inspect.getabsfile(fn) + lineno = inspect.getsourcelines(fn)[1] + self._message = "{} file://{} {}".format(fn.__name__, path, lineno) + + def __call__(self, *args, **kwargs): + print(' ' * _STACK_SIZE + self._message) + return self.fn(*args, **kwargs) + + def debug_logged(fn): + if isinstance(fn, DebugLogged): + return fn + return DebugLogged(fn) +else: + def debug_logged(fn): + return fn + + def dispatched_interpretation(fn): """ Decorator to create a dispatched interpretation function. """ registry = KeyedRegistry(default=lambda *args: None) - fn.register = registry.register + if _DEBUG: + fn.register = lambda *args: lambda fn: registry.register(*args)(debug_logged(fn)) + else: + fn.register = registry.register fn.dispatch = registry.__call__ return fn diff --git a/funsor/joint.py b/funsor/joint.py index 6ff1865da..72b93713e 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -6,6 +6,7 @@ from six import add_metaclass from six.moves import reduce +import funsor.interpreter as interpreter import funsor.ops as ops from funsor.delta import Delta from funsor.domains import reals @@ -293,6 +294,7 @@ def _joint_integrator(fn): """ Decorator for Integrate(Joint(...), ...) patterns. """ + fn = interpreter.debug_logged(fn) return integrator(functools.partial(_simplify_integrate, fn)) diff --git a/funsor/registry.py b/funsor/registry.py index 67b5bb7b4..086b1c35d 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -1,4 +1,3 @@ - from __future__ import absolute_import, division, print_function from collections import defaultdict diff --git a/funsor/terms.py b/funsor/terms.py index 94938202e..f182f7ff4 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -258,7 +258,7 @@ def sample(self, sampled_vars, sample_inputs=None): if sampled_vars.isdisjoint(self.inputs): return self - result = self.unscaled_sample(sampled_vars, sample_inputs) + result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars, sample_inputs) if sample_inputs is not None: log_scale = 0 for var, domain in sample_inputs.items(): @@ -614,7 +614,7 @@ def eager_subs(arg, subs): assert isinstance(subs, tuple) if not any(k in arg.inputs for k, v in subs): return arg - return arg.eager_subs(subs) + return interpreter.debug_logged(arg.eager_subs)(subs) _PREFIX = { @@ -647,14 +647,14 @@ def eager_subs(self, subs): @eager.register(Unary, Op, Funsor) def eager_unary(op, arg): - return arg.eager_unary(op) + return interpreter.debug_logged(arg.eager_unary)(op) @eager.register(Unary, AssociativeOp, Funsor) def eager_unary(op, arg): if not arg.output.shape: return arg - return arg.eager_unary(op) + return interpreter.debug_logged(arg.eager_unary)(op) _INFIX = { @@ -730,12 +730,12 @@ def eager_reduce(self, op, reduced_vars): @eager.register(Reduce, AssociativeOp, Funsor, frozenset) def eager_reduce(op, arg, reduced_vars): - return arg.eager_reduce(op, reduced_vars) + return interpreter.debug_logged(arg.eager_reduce)(op, reduced_vars) @sequential.register(Reduce, AssociativeOp, Funsor, frozenset) def sequential_reduce(op, arg, reduced_vars): - return arg.sequential_reduce(op, reduced_vars) + return interpreter.debug_logged(arg.sequential_reduce)(op, reduced_vars) class NumberMeta(FunsorMeta):