Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constraints for multiple intervals #1829

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
"unit_interval",
"zero_sum",
"Constraint",
"union_of_closed_intervals",
"intersection_of_closed_intervals",
]

import math
Expand Down Expand Up @@ -766,6 +768,140 @@ def tree_flatten(self):
return (self.event_dim,), (("event_dim",), dict())


class _UnionOfClosedIntervals(Constraint):
"""A constraint representing the union of multiple intervals."""

event_dim = 1

def __init__(self, lower_bounds, upper_bounds) -> None:
assert isinstance(lower_bounds, (list, tuple, jnp.ndarray))
assert isinstance(upper_bounds, (list, tuple, jnp.ndarray))
assert len(lower_bounds) == len(upper_bounds), (
f"lower_bounds and upper_bounds must have the same length, "
f"but got {len(lower_bounds)} and {len(upper_bounds)}"
)
self.lower_bounds = jnp.asarray(lower_bounds)
self.upper_bounds = jnp.asarray(upper_bounds)

def __call__(self, x):
r"""Check if the input is within the union of intervals

.. math::
x \in \bigcup_{i=1}^{n} [a_i, b_i] \implies \bigvee_{i=1}^{n} (x \in [a_i, b_i])

:param x: The input to be checked.
"""
return jnp.any((x >= self.lower_bounds) & (x <= self.upper_bounds), axis=-1)

def feasible_like(self, prototype):
return jnp.broadcast_to(
(self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype)
)

def tree_flatten(self):
return (self.lower_bounds, self.upper_bounds), (
("lower_bounds", "upper_bounds"),
dict(),
)

def __eq__(self, other):
if not isinstance(other, _UnionOfClosedIntervals):
return False
return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal(
self.upper_bounds, other.upper_bounds
)


class _IntersectionOfClosedIntervals(Constraint):
"""A constraint representing the intersection of multiple intervals."""

event_dim = 1

def __init__(self, lower_bounds, upper_bounds) -> None:
assert isinstance(lower_bounds, (list, tuple, jnp.ndarray))
assert isinstance(upper_bounds, (list, tuple, jnp.ndarray))
assert len(lower_bounds) == len(upper_bounds), (
f"lower_bounds and upper_bounds must have the same length, "
f"but got {len(lower_bounds)} and {len(upper_bounds)}"
)
self.lower_bounds = jnp.asarray(lower_bounds)
self.upper_bounds = jnp.asarray(upper_bounds)

def __call__(self, x):
r"""Check if the input is within the union of intervals

.. math::
x \in \bigcap_{i=1}^{n} [a_i, b_i] \implies \bigwedge_{i=1}^{n} (x \in [a_i, b_i])

:param x: The input to be checked.
"""
return jnp.all((x >= self.lower_bounds) & (x <= self.upper_bounds), axis=-1)

def feasible_like(self, prototype):
return jnp.broadcast_to(
(self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype)
)

def tree_flatten(self):
return (self.lower_bounds, self.upper_bounds), (
("lower_bounds", "upper_bounds"),
dict(),
)

def __eq__(self, other):
if not isinstance(other, _IntersectionOfClosedIntervals):
return False
return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal(
self.upper_bounds, other.upper_bounds
)


class _UniqueIntervals(Constraint):
"""A constraint representing a set of unique intervals for a single dimension."""

event_dim = 1

def __init__(self, lower_bounds, upper_bounds) -> None:
assert isinstance(lower_bounds, (list, tuple, jnp.ndarray))
assert isinstance(upper_bounds, (list, tuple, jnp.ndarray))
assert len(lower_bounds) == len(upper_bounds), (
f"lower_bounds and upper_bounds must have the same length, "
f"but got {len(lower_bounds)} and {len(upper_bounds)}"
)
self.lower_bounds = jnp.asarray(lower_bounds)
self.upper_bounds = jnp.asarray(upper_bounds)

def __call__(self, x):
r"""Check if the input is within the specified intervals

.. math::
\bigwedge_{i=1}^{n} (x_i \in [a_i, b_i])

:param x: The input to be checked.
"""
less_than = jnp.all(x <= self.upper_bounds, axis=-1)
greater_than = jnp.all(x >= self.lower_bounds, axis=-1)
return less_than & greater_than

def feasible_like(self, prototype):
return jnp.broadcast_to(
(self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype)
)

def tree_flatten(self):
return (self.lower_bounds, self.upper_bounds), (
("lower_bounds", "upper_bounds"),
dict(),
)

def __eq__(self, other):
if not isinstance(other, _UniqueIntervals):
return False
return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal(
self.upper_bounds, other.upper_bounds
)


# TODO: Make types consistent
# See https://github.com/pytorch/pytorch/issues/50616

Expand Down Expand Up @@ -805,3 +941,6 @@ def tree_flatten(self):
unit_interval = _UnitInterval()
open_interval = _OpenInterval
zero_sum = _ZeroSum
union_of_closed_intervals = _UnionOfClosedIntervals
intersection_of_closed_intervals = _IntersectionOfClosedIntervals
unique_intervals = _UniqueIntervals
26 changes: 26 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,3 +1636,29 @@ def _transform_to_simplex(constraint):
@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
return ZeroSumTransform(constraint.event_dim)


@biject_to.register(constraints.union_of_closed_intervals)
@biject_to.register(constraints.intersection_of_closed_intervals)
def _transform_to_union_of_intervals(constraint):
return ComposeTransform(
[
ReshapeTransform(
forward_shape=constraint.lower_bounds.shape,
inverse_shape=(constraint.lower_bounds.size,),
),
ExpTransform(),
Copy link
Member

@fehiepsi fehiepsi Jul 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this composed transform is bijective. Could you clarify?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this code snippet to test the constraints of the system. This was the best example I could come up with based on a quick review of the test cases. Please note that this code is not based on any specific logic, so feel free to correct me if I'm wrong or suggest improvements!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Probably the tests pass because the domain is quite loose. Exp transform has positive codomain, not union of intervals.

]
)


@biject_to.register(constraints.unique_intervals)
def _transform_to_unique_intervals(constraint):
return ComposeTransform(
[
ReshapeTransform(
forward_shape=constraint.lower_bounds.shape,
inverse_shape=(constraint.lower_bounds.size,),
),
]
)
13 changes: 13 additions & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,19 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
),
"open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()),
"zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)),
"union_of_closed_intervals": T(
constraints.union_of_closed_intervals,
(_a((-100, -50, 0, 50)), _a((-50, 0, 50, 100))),
dict(),
),
"intersection_of_closed_intervals": T(
constraints.intersection_of_closed_intervals,
(_a((-100, -130, -150, -250)), _a((200, 100, 250, 100))),
dict(),
),
"unique_intervals": T(
constraints.unique_intervals, (_a((-10, -8, -6, -4)), _a((4, 6, 8, 10))), dict()
),
}

# TODO: BijectorConstraint
Expand Down
16 changes: 16 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,22 @@ def test_batched_recursive_linear_transform():
(constraints.softplus_positive, (2,)),
(constraints.unit_interval, (4,)),
(constraints.nonnegative, (7,)),
(
constraints.union_of_closed_intervals(
(-100, -50, 0, 50), (-50, 0, 50, 100)
),
(4,),
),
(
constraints.intersection_of_closed_intervals(
(-100, -130, -150, -250), (200, 100, 250, 100)
),
(4,),
),
(
constraints.unique_intervals((-10, -8, -6, -4), (4, 6, 8, 10)),
(4,),
),
],
ids=str,
)
Expand Down
Loading