Skip to content

Commit

Permalink
Log function and filename when FUNSOR_DEBUG=1 (#101)
Browse files Browse the repository at this point in the history
* Add function logging and filename logging when FUNSOR_DEBUG=1

* Enable more functions to be logged
  • Loading branch information
fritzo authored and eb8680 committed Mar 26, 2019
1 parent 951630c commit ea7feb2
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 8 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ format: FORCE

test: lint FORCE
pytest -v test
FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
python examples/discrete_hmm.py -n 2
python examples/discrete_hmm.py -n 2 -t 50 --lazy
python examples/kalman_filter.py --xfail-if-not-implemented
Expand Down
2 changes: 2 additions & 0 deletions funsor/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
from collections import OrderedDict

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.optimizer import Finitary, optimize
from funsor.sum_product import _partition
Expand Down Expand Up @@ -51,6 +52,7 @@ def contractor(fn):
"""
Decorator for contract implementations to simplify inputs.
"""
fn = interpreter.debug_logged(fn)
return functools.partial(_simplify_contract, fn)


Expand Down
2 changes: 2 additions & 0 deletions funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
from collections import OrderedDict

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.contract import Contract
from funsor.terms import Funsor, Reduce, eager
Expand Down Expand Up @@ -58,6 +59,7 @@ def integrator(fn):
"""
Decorator for integration implementations.
"""
fn = interpreter.debug_logged(fn)
return functools.partial(_simplify_integrate, fn)


Expand Down
31 changes: 30 additions & 1 deletion funsor/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import, division, print_function

import functools
import inspect
import os
import re
import types
Expand All @@ -13,6 +15,7 @@
from funsor.registry import KeyedRegistry
from funsor.six import singledispatch

_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_DEBUG = int(os.environ.get("FUNSOR_DEBUG", 0))
_STACK_SIZE = 0

Expand Down Expand Up @@ -117,12 +120,38 @@ def _reinterpret_ordereddict(x):
return OrderedDict((key, reinterpret(value)) for key, value in x.items())


if _DEBUG:
class DebugLogged(object):
def __init__(self, fn):
self.fn = fn
while isinstance(fn, functools.partial):
fn = fn.func
path = inspect.getabsfile(fn)
lineno = inspect.getsourcelines(fn)[1]
self._message = "{} file://{} {}".format(fn.__name__, path, lineno)

def __call__(self, *args, **kwargs):
print(' ' * _STACK_SIZE + self._message)
return self.fn(*args, **kwargs)

def debug_logged(fn):
if isinstance(fn, DebugLogged):
return fn
return DebugLogged(fn)
else:
def debug_logged(fn):
return fn


def dispatched_interpretation(fn):
"""
Decorator to create a dispatched interpretation function.
"""
registry = KeyedRegistry(default=lambda *args: None)
fn.register = registry.register
if _DEBUG:
fn.register = lambda *args: lambda fn: registry.register(*args)(debug_logged(fn))
else:
fn.register = registry.register
fn.dispatch = registry.__call__
return fn

Expand Down
2 changes: 2 additions & 0 deletions funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from six import add_metaclass
from six.moves import reduce

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import reals
Expand Down Expand Up @@ -293,6 +294,7 @@ def _joint_integrator(fn):
"""
Decorator for Integrate(Joint(...), ...) patterns.
"""
fn = interpreter.debug_logged(fn)
return integrator(functools.partial(_simplify_integrate, fn))


Expand Down
1 change: 0 additions & 1 deletion funsor/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from __future__ import absolute_import, division, print_function

from collections import defaultdict
Expand Down
12 changes: 6 additions & 6 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def sample(self, sampled_vars, sample_inputs=None):
if sampled_vars.isdisjoint(self.inputs):
return self

result = self.unscaled_sample(sampled_vars, sample_inputs)
result = interpreter.debug_logged(self.unscaled_sample)(sampled_vars, sample_inputs)
if sample_inputs is not None:
log_scale = 0
for var, domain in sample_inputs.items():
Expand Down Expand Up @@ -614,7 +614,7 @@ def eager_subs(arg, subs):
assert isinstance(subs, tuple)
if not any(k in arg.inputs for k, v in subs):
return arg
return arg.eager_subs(subs)
return interpreter.debug_logged(arg.eager_subs)(subs)


_PREFIX = {
Expand Down Expand Up @@ -647,14 +647,14 @@ def eager_subs(self, subs):

@eager.register(Unary, Op, Funsor)
def eager_unary(op, arg):
return arg.eager_unary(op)
return interpreter.debug_logged(arg.eager_unary)(op)


@eager.register(Unary, AssociativeOp, Funsor)
def eager_unary(op, arg):
if not arg.output.shape:
return arg
return arg.eager_unary(op)
return interpreter.debug_logged(arg.eager_unary)(op)


_INFIX = {
Expand Down Expand Up @@ -730,12 +730,12 @@ def eager_reduce(self, op, reduced_vars):

@eager.register(Reduce, AssociativeOp, Funsor, frozenset)
def eager_reduce(op, arg, reduced_vars):
return arg.eager_reduce(op, reduced_vars)
return interpreter.debug_logged(arg.eager_reduce)(op, reduced_vars)


@sequential.register(Reduce, AssociativeOp, Funsor, frozenset)
def sequential_reduce(op, arg, reduced_vars):
return arg.sequential_reduce(op, reduced_vars)
return interpreter.debug_logged(arg.sequential_reduce)(op, reduced_vars)


class NumberMeta(FunsorMeta):
Expand Down

0 comments on commit ea7feb2

Please sign in to comment.