From c5cf2b64bc5126dc149875fbef07b583fca7d811 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Sun, 7 Jul 2024 01:30:20 +0500 Subject: [PATCH 1/3] chore: union of closed intervals --- numpyro/distributions/constraints.py | 46 ++++++++++++++++++++++++++++ numpyro/distributions/transforms.py | 13 ++++++++ test/test_constraints.py | 5 +++ test/test_transforms.py | 6 ++++ 4 files changed, 70 insertions(+) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 56075078f..79ce55320 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -60,6 +60,7 @@ "unit_interval", "zero_sum", "Constraint", + "union_of_closed_intervals", ] import math @@ -766,6 +767,50 @@ def tree_flatten(self): return (self.event_dim,), (("event_dim",), dict()) +class _UnionOfClosedIntervals(Constraint): + """A constraint representing the union of multiple intervals.""" + + event_dim = 1 + + def __init__(self, lower_bounds, upper_bounds) -> None: + assert isinstance(lower_bounds, (list, tuple, jnp.ndarray)) + assert isinstance(upper_bounds, (list, tuple, jnp.ndarray)) + assert len(lower_bounds) == len(upper_bounds), ( + f"lower_bounds and upper_bounds must have the same length, " + f"but got {len(lower_bounds)} and {len(upper_bounds)}" + ) + self.lower_bounds = jnp.asarray(lower_bounds) + self.upper_bounds = jnp.asarray(upper_bounds) + + def __call__(self, x): + r"""Check if the input is within the union of intervals + + .. math:: + x \in \bigcup_{i=1}^{n} [a_i, b_i] \implies \bigvee_{i=1}^{n} (x \in [a_i, b_i]) + + :param x: The input to be checked. + """ + return jnp.any((x >= self.lower_bounds) & (x <= self.upper_bounds), axis=-1) + + def feasible_like(self, prototype): + return jnp.broadcast_to( + (self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype) + ) + + def tree_flatten(self): + return (self.lower_bounds, self.upper_bounds), ( + ("lower_bounds", "upper_bounds"), + dict(), + ) + + def __eq__(self, other): + if not isinstance(other, _UnionOfClosedIntervals): + return False + return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal( + self.upper_bounds, other.upper_bounds + ) + + # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 @@ -805,3 +850,4 @@ def tree_flatten(self): unit_interval = _UnitInterval() open_interval = _OpenInterval zero_sum = _ZeroSum +union_of_closed_intervals = _UnionOfClosedIntervals diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 40c31fa51..44fc9bc46 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1636,3 +1636,16 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): return ZeroSumTransform(constraint.event_dim) + + +@biject_to.register(constraints.union_of_closed_intervals) +def _transform_to_union_of_intervals(constraint): + return ComposeTransform( + [ + ReshapeTransform( + forward_shape=constraint.lower_bounds.shape, + inverse_shape=(constraint.lower_bounds.size,), + ), + ExpTransform(), + ] + ) diff --git a/test/test_constraints.py b/test/test_constraints.py index acd96732e..9de7c4716 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -68,6 +68,11 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): ), "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()), "zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)), + "union_of_closed_intervals": T( + constraints.union_of_closed_intervals, + (_a((-100, -50, 0, 50)), _a((-50, 0, 50, 100))), + dict(), + ), } # TODO: BijectorConstraint diff --git a/test/test_transforms.py b/test/test_transforms.py index 997959244..a7c9de918 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -383,6 +383,12 @@ def test_batched_recursive_linear_transform(): (constraints.softplus_positive, (2,)), (constraints.unit_interval, (4,)), (constraints.nonnegative, (7,)), + ( + constraints.union_of_closed_intervals( + (-100, -50, 0, 50), (-50, 0, 50, 100) + ), + (4,), + ), ], ids=str, ) From 5490b5829ff2047bee27206b3f61592cf89489ff Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Sun, 7 Jul 2024 01:45:27 +0500 Subject: [PATCH 2/3] chore: intersection of closed intervals --- numpyro/distributions/constraints.py | 46 ++++++++++++++++++++++++++++ numpyro/distributions/transforms.py | 1 + test/test_constraints.py | 5 +++ test/test_transforms.py | 6 ++++ 4 files changed, 58 insertions(+) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 79ce55320..923f7afa3 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -61,6 +61,7 @@ "zero_sum", "Constraint", "union_of_closed_intervals", + "intersection_of_closed_intervals", ] import math @@ -811,6 +812,50 @@ def __eq__(self, other): ) +class _IntersectionOfClosedIntervals(Constraint): + """A constraint representing the intersection of multiple intervals.""" + + event_dim = 1 + + def __init__(self, lower_bounds, upper_bounds) -> None: + assert isinstance(lower_bounds, (list, tuple, jnp.ndarray)) + assert isinstance(upper_bounds, (list, tuple, jnp.ndarray)) + assert len(lower_bounds) == len(upper_bounds), ( + f"lower_bounds and upper_bounds must have the same length, " + f"but got {len(lower_bounds)} and {len(upper_bounds)}" + ) + self.lower_bounds = jnp.asarray(lower_bounds) + self.upper_bounds = jnp.asarray(upper_bounds) + + def __call__(self, x): + r"""Check if the input is within the union of intervals + + .. math:: + x \in \bigcap_{i=1}^{n} [a_i, b_i] \implies \bigwedge_{i=1}^{n} (x \in [a_i, b_i]) + + :param x: The input to be checked. + """ + return jnp.all((x >= self.lower_bounds) & (x <= self.upper_bounds), axis=-1) + + def feasible_like(self, prototype): + return jnp.broadcast_to( + (self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype) + ) + + def tree_flatten(self): + return (self.lower_bounds, self.upper_bounds), ( + ("lower_bounds", "upper_bounds"), + dict(), + ) + + def __eq__(self, other): + if not isinstance(other, _IntersectionOfClosedIntervals): + return False + return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal( + self.upper_bounds, other.upper_bounds + ) + + # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 @@ -851,3 +896,4 @@ def __eq__(self, other): open_interval = _OpenInterval zero_sum = _ZeroSum union_of_closed_intervals = _UnionOfClosedIntervals +intersection_of_closed_intervals = _IntersectionOfClosedIntervals diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 44fc9bc46..42dd26fb1 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1639,6 +1639,7 @@ def _transform_to_zero_sum(constraint): @biject_to.register(constraints.union_of_closed_intervals) +@biject_to.register(constraints.intersection_of_closed_intervals) def _transform_to_union_of_intervals(constraint): return ComposeTransform( [ diff --git a/test/test_constraints.py b/test/test_constraints.py index 9de7c4716..580b938f7 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -73,6 +73,11 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): (_a((-100, -50, 0, 50)), _a((-50, 0, 50, 100))), dict(), ), + "intersection_of_closed_intervals": T( + constraints.intersection_of_closed_intervals, + (_a((-100, -130, -150, -250)), _a((200, 100, 250, 100))), + dict(), + ), } # TODO: BijectorConstraint diff --git a/test/test_transforms.py b/test/test_transforms.py index a7c9de918..aadaee48f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -389,6 +389,12 @@ def test_batched_recursive_linear_transform(): ), (4,), ), + ( + constraints.intersection_of_closed_intervals( + (-100, -130, -150, -250), (200, 100, 250, 100) + ), + (4,), + ), ], ids=str, ) From 2a102fd2757a281a752f22c14a84dca1ec900628 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Sun, 7 Jul 2024 02:01:41 +0500 Subject: [PATCH 3/3] chore: unique intervals --- numpyro/distributions/constraints.py | 47 ++++++++++++++++++++++++++++ numpyro/distributions/transforms.py | 12 +++++++ test/test_constraints.py | 3 ++ test/test_transforms.py | 4 +++ 4 files changed, 66 insertions(+) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 923f7afa3..21c06a44b 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -856,6 +856,52 @@ def __eq__(self, other): ) +class _UniqueIntervals(Constraint): + """A constraint representing a set of unique intervals for a single dimension.""" + + event_dim = 1 + + def __init__(self, lower_bounds, upper_bounds) -> None: + assert isinstance(lower_bounds, (list, tuple, jnp.ndarray)) + assert isinstance(upper_bounds, (list, tuple, jnp.ndarray)) + assert len(lower_bounds) == len(upper_bounds), ( + f"lower_bounds and upper_bounds must have the same length, " + f"but got {len(lower_bounds)} and {len(upper_bounds)}" + ) + self.lower_bounds = jnp.asarray(lower_bounds) + self.upper_bounds = jnp.asarray(upper_bounds) + + def __call__(self, x): + r"""Check if the input is within the specified intervals + + .. math:: + \bigwedge_{i=1}^{n} (x_i \in [a_i, b_i]) + + :param x: The input to be checked. + """ + less_than = jnp.all(x <= self.upper_bounds, axis=-1) + greater_than = jnp.all(x >= self.lower_bounds, axis=-1) + return less_than & greater_than + + def feasible_like(self, prototype): + return jnp.broadcast_to( + (self.lower_bounds + self.upper_bounds) / 2, jnp.shape(prototype) + ) + + def tree_flatten(self): + return (self.lower_bounds, self.upper_bounds), ( + ("lower_bounds", "upper_bounds"), + dict(), + ) + + def __eq__(self, other): + if not isinstance(other, _UniqueIntervals): + return False + return jnp.array_equal(self.lower_bounds, other.lower_bounds) & jnp.array_equal( + self.upper_bounds, other.upper_bounds + ) + + # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 @@ -897,3 +943,4 @@ def __eq__(self, other): zero_sum = _ZeroSum union_of_closed_intervals = _UnionOfClosedIntervals intersection_of_closed_intervals = _IntersectionOfClosedIntervals +unique_intervals = _UniqueIntervals diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 42dd26fb1..7580eaa40 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1650,3 +1650,15 @@ def _transform_to_union_of_intervals(constraint): ExpTransform(), ] ) + + +@biject_to.register(constraints.unique_intervals) +def _transform_to_unique_intervals(constraint): + return ComposeTransform( + [ + ReshapeTransform( + forward_shape=constraint.lower_bounds.shape, + inverse_shape=(constraint.lower_bounds.size,), + ), + ] + ) diff --git a/test/test_constraints.py b/test/test_constraints.py index 580b938f7..eb23efbdf 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -78,6 +78,9 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): (_a((-100, -130, -150, -250)), _a((200, 100, 250, 100))), dict(), ), + "unique_intervals": T( + constraints.unique_intervals, (_a((-10, -8, -6, -4)), _a((4, 6, 8, 10))), dict() + ), } # TODO: BijectorConstraint diff --git a/test/test_transforms.py b/test/test_transforms.py index aadaee48f..89942ebfe 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -395,6 +395,10 @@ def test_batched_recursive_linear_transform(): ), (4,), ), + ( + constraints.unique_intervals((-10, -8, -6, -4), (4, 6, 8, 10)), + (4,), + ), ], ids=str, )