diff --git a/funsor/delta.py b/funsor/delta.py index d4582e18a..5102d384f 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -27,7 +27,7 @@ class Delta(Funsor): Normalized delta distribution binding a single variable. :param str name: Name of the bound variable. - :param Funsor value: Value of the bound variable. + :param Funsor point: Value of the bound variable. :param Funsor log_density: Optional log density to be added when evaluating at a point. This is needed to make :class:`Delta` closed under differentiable substitution. diff --git a/test/test_distributions.py b/test/test_distributions.py index ae1266249..622a0678a 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -8,6 +8,7 @@ import funsor import funsor.distributions as dist +from funsor.delta import Delta from funsor.domains import bint, reals from funsor.gaussian import Gaussian from funsor.terms import Variable @@ -80,6 +81,14 @@ def delta(v, log_density, value): assert_close(actual, expected) +def test_delta_delta(): + v = Variable('v', reals(2)) + point = Tensor(torch.randn(2)) + log_density = Tensor(torch.tensor(0.5)) + d = dist.Delta(point, log_density, v) + assert d is Delta('v', point, log_density) + + def test_normal_defaults(): loc = Variable('loc', reals()) scale = Variable('scale', reals())