Skip to content

Commit

Permalink
Remove .log_density field from Delta funsor
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Mar 10, 2019
1 parent 3825deb commit 897f523
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 20 deletions.
21 changes: 6 additions & 15 deletions funsor/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@

class DeltaMeta(FunsorMeta):
"""
Wrapper to fill in defaults.
Wrapper to convert point to a funsor.
"""
def __call__(cls, name, point, log_density=0):
def __call__(cls, name, point):
point = to_funsor(point)
log_density = to_funsor(log_density)
return super(DeltaMeta, cls).__call__(name, point, log_density)
return super(DeltaMeta, cls).__call__(name, point)


@add_metaclass(DeltaMeta)
Expand All @@ -28,23 +27,16 @@ class Delta(Funsor):
:param str name: Name of the bound variable.
:param Funsor point: Value of the bound variable.
:param Funsor log_density: Optional log density to be added when evaluating
at a point. This is needed to make :class:`Delta` closed under
differentiable substitution.
"""
def __init__(self, name, point, log_density=0):
def __init__(self, name, point):
assert isinstance(name, str)
assert isinstance(point, Funsor)
assert isinstance(log_density, Funsor)
assert log_density.output == reals()
inputs = OrderedDict([(name, point.output)])
inputs.update(point.inputs)
inputs.update(log_density.inputs)
output = reals()
super(Delta, self).__init__(inputs, output)
self.name = name
self.point = point
self.log_density = log_density

def eager_subs(self, subs):
value = None
Expand All @@ -62,17 +54,16 @@ def eager_subs(self, subs):

name = self.name
point = self.point.eager_subs(index_part)
log_density = self.log_density.eager_subs(index_part)
if value is not None:
if isinstance(value, Variable):
name = value.name
elif isinstance(value, (Number, Tensor)) and isinstance(point, (Number, Tensor)):
return (value == point).all().log() + log_density
return (value == point).all().log()
else:
# TODO Compute a jacobian, update log_prob, and emit another Delta.
raise ValueError('Cannot substitute a {} into a Delta'
.format(type(value).__name__))
return Delta(name, point, log_density)
return Delta(name, point)

def eager_reduce(self, op, reduced_vars):
if op is ops.logaddexp:
Expand Down
7 changes: 3 additions & 4 deletions funsor/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,16 @@ def eager_delta(v, log_density, value):
return Tensor(data, inputs)


@eager.register(Delta, Funsor, Funsor, Variable)
@eager.register(Delta, Variable, Funsor, Variable)
@eager.register(Delta, (Funsor, Variable), Funsor, Variable)
def eager_delta(v, log_density, value):
assert v.output == value.output
return funsor.delta.Delta(value.name, v, log_density)
return funsor.delta.Delta(value.name, v) + log_density


@eager.register(Delta, Variable, Funsor, Funsor)
def eager_delta(v, log_density, value):
assert v.output == value.output
return funsor.delta.Delta(v.name, value, log_density)
return funsor.delta.Delta(v.name, value) + log_density


class Normal(Distribution):
Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_delta_delta():
point = Tensor(torch.randn(2))
log_density = Tensor(torch.tensor(0.5))
d = dist.Delta(point, log_density, v)
assert d is Delta('v', point, log_density)
assert d is Delta('v', point) + log_density


def test_normal_defaults():
Expand Down

0 comments on commit 897f523

Please sign in to comment.