diff --git a/claripy/__init__.py b/claripy/__init__.py index 12ab00aae..a89ea7368 100644 --- a/claripy/__init__.py +++ b/claripy/__init__.py @@ -1,7 +1,16 @@ from __future__ import annotations from claripy import algorithm, ast -from claripy.algorithm import is_false, is_true, simplify +from claripy.algorithm import ( + cardinality, + constraint_to_si, + is_false, + is_true, + multivalued, + simplify, + singlevalued, + stride, +) from claripy.annotation import Annotation, RegionAnnotation, SimplificationAvoidanceAnnotation from claripy.ast.bool import ( And, @@ -10,7 +19,6 @@ If, Not, Or, - constraint_to_si, false, ite_cases, ite_dict, @@ -210,4 +218,8 @@ "SolverReplacement", "SolverStrings", "SolverVSA", + "cardinality", + "multivalued", + "singlevalued", + "stride", ) diff --git a/claripy/algorithm/__init__.py b/claripy/algorithm/__init__.py index f2f139ed0..3f7299c7f 100644 --- a/claripy/algorithm/__init__.py +++ b/claripy/algorithm/__init__.py @@ -2,9 +2,15 @@ from .bool_check import is_false, is_true from .simplify import simplify +from .vsa_helper import cardinality, constraint_to_si, multivalued, singlevalued, stride __all__ = ( "is_false", "is_true", "simplify", + "singlevalued", + "multivalued", + "cardinality", + "constraint_to_si", + "stride", ) diff --git a/claripy/algorithm/simplify.py b/claripy/algorithm/simplify.py index 9bbd3940b..21335234d 100644 --- a/claripy/algorithm/simplify.py +++ b/claripy/algorithm/simplify.py @@ -1,10 +1,13 @@ from __future__ import annotations import logging +from contextlib import suppress from itertools import chain from typing import TypeVar +import claripy from claripy.ast.base import Base, SimplificationLevel +from claripy.errors import BackendError log = logging.getLogger(__name__) @@ -19,8 +22,11 @@ def simplify(expr: T) -> T: if expr.is_leaf() or expr._simplified == SimplificationLevel.FULL_SIMPLIFY: return expr - simplified = expr._first_backend("simplify") - if simplified is None: + for backend in claripy.backends.all_backends: + with suppress(BackendError): + simplified = backend.simplify(expr) + break + else: log.debug("Unable to simplify expression") return expr diff --git a/claripy/algorithm/vsa_helper.py b/claripy/algorithm/vsa_helper.py new file mode 100644 index 000000000..84d89dc70 --- /dev/null +++ b/claripy/algorithm/vsa_helper.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import claripy + +if TYPE_CHECKING: + from claripy.ast import BV + + +def singlevalued(expr: BV) -> bool: + """Determines if the expression is single-valued.""" + return claripy.backends.vsa.singlevalued(expr) + + +def multivalued(expr: BV) -> bool: + """Determines if the expression is multi-valued.""" + return claripy.backends.vsa.multivalued(expr) + + +def cardinality(expr: BV) -> int: + """Returns the number of possible values the expression can take.""" + return claripy.backends.vsa.cardinality(expr) + + +def stride(expr: BV) -> int: + """Returns the stride of the expression.""" + return claripy.backends.vsa.stride(expr) + + +def constraint_to_si(expr: BV) -> tuple[bool, list[tuple[BV, BV]]]: + """ + Convert a constraint to SI if possible. + + :param expr: The expression to convert. + :return: A tuple of satisfiable and a list of replacements. + """ + + return claripy.backends.vsa.constraint_to_si(expr) diff --git a/claripy/ast/base.py b/claripy/ast/base.py index 5a62444e8..b4b453e7f 100644 --- a/claripy/ast/base.py +++ b/claripy/ast/base.py @@ -1207,17 +1207,6 @@ def ite_excavated(self: T) -> T: # these are convenience operations # - def _first_backend(self, what): - for b in claripy.backends.all_backends: - if b in self._errored: - continue - - try: - return getattr(b, what)(self) - except BackendError: - pass - return None - @property def concrete_value(self): try: @@ -1228,18 +1217,6 @@ def concrete_value(self): except BackendError: return self - @property - def singlevalued(self) -> bool: - return self._first_backend("singlevalued") - - @property - def multivalued(self) -> bool: - return self._first_backend("multivalued") - - @property - def cardinality(self) -> int: - return self._first_backend("cardinality") - @property def concrete(self) -> bool: return not self.symbolic diff --git a/claripy/ast/bool.py b/claripy/ast/bool.py index 4417a936e..9d07ec3d3 100644 --- a/claripy/ast/bool.py +++ b/claripy/ast/bool.py @@ -4,7 +4,6 @@ from functools import lru_cache from typing import TYPE_CHECKING, overload -import claripy from claripy import operations from claripy.algorithm.bool_check import is_false, is_true from claripy.ast.base import ASTCacheKey, Base, _make_name @@ -227,28 +226,3 @@ def reverse_ite_cases(ast): queue.append((And(condition, Not(ast.args[0])), ast.args[2])) else: yield condition, ast - - -def constraint_to_si(expr): - """ - Convert a constraint to SI if possible. - - :param expr: - :return: - """ - - satisfiable = True - replace_list = [] - - satisfiable, replace_list = claripy.backends.vsa.constraint_to_si(expr) - - # Make sure the replace_list are all ast.bvs - for i in range(len(replace_list)): # pylint:disable=consider-using-enumerate - ori, new = replace_list[i] - if not isinstance(new, Base): - new = claripy.BVS( - new.name, new._bits, min=new._lower_bound, max=new._upper_bound, stride=new._stride, explicit_name=True - ) - replace_list[i] = (ori, new) - - return satisfiable, replace_list diff --git a/claripy/backends/backend_vsa/backend_vsa.py b/claripy/backends/backend_vsa/backend_vsa.py index 7589f016b..392ba4177 100644 --- a/claripy/backends/backend_vsa/backend_vsa.py +++ b/claripy/backends/backend_vsa/backend_vsa.py @@ -470,5 +470,8 @@ def CreateTopStridedInterval(bits, name=None, uninitialized=False): def constraint_to_si(self, expr): return Balancer(self, expr).compat_ret + def stride(self, expr): + return self.convert(expr).stride + BackendVSA.CreateStridedInterval = staticmethod(CreateStridedInterval) diff --git a/claripy/balancer.py b/claripy/balancer.py index cddf04658..2ba497a9b 100644 --- a/claripy/balancer.py +++ b/claripy/balancer.py @@ -5,6 +5,7 @@ import claripy import claripy.backends.backend_vsa as vsa +from claripy.algorithm.vsa_helper import cardinality from .ast.base import Base from .ast.bool import Bool @@ -105,7 +106,7 @@ def _same_bound_bv(self, a): @staticmethod def _cardinality(a): - return a.cardinality if isinstance(a, Base) else 0 + return cardinality(a) if isinstance(a, Base) else 0 @staticmethod def _min(a, signed=False): @@ -176,7 +177,7 @@ def _align_ast(self, a): try: if isinstance(a, BV): return self._align_bv(a) - if isinstance(a, Bool) and len(a.args) == 2 and a.args[1].cardinality > a.args[0].cardinality: + if isinstance(a, Bool) and len(a.args) == 2 and cardinality(a.args[1]) > cardinality(a.args[0]): return self._reverse_comparison(a) return a except ClaripyBalancerError: @@ -275,7 +276,7 @@ def _handleable_truism(t): if len(t.args) < 2: l.debug("can't do anything with an unop bool") return None - if t.args[0].cardinality > 1 and t.args[1].cardinality > 1: + if claripy.cardinality(t.args[0]) > 1 and claripy.cardinality(t.args[1]) > 1: l.debug("can't do anything because we have multiple multivalued guys") return False if t.op == "If": @@ -289,7 +290,7 @@ def _adjust_truism(t): Swap the operands of the truism if the unknown variable is on the right side and the concrete value is on the left side. """ - if t.args[0].cardinality == 1 and t.args[1].cardinality > 1: + if claripy.cardinality(t.args[0]) == 1 and claripy.cardinality(t.args[1]) > 1: return Balancer._reverse_comparison(t) return t @@ -383,7 +384,7 @@ def _balance(self, truism): try: inner_aligned = self._align_truism(truism) - if inner_aligned.args[1].cardinality > 1: + if cardinality(inner_aligned.args[1]) > 1: l.debug("can't do anything because we have multiple multivalued guys") return truism @@ -653,7 +654,7 @@ def _handle_comparison(self, truism): def _handle___eq__(self, truism): lhs, rhs = truism.args - if rhs.cardinality != 1: + if cardinality(rhs) != 1: common = self._same_bound_bv(lhs.intersection(rhs)) mn, mx = self._range(common) self._add_upper_bound(lhs, mx) @@ -667,7 +668,7 @@ def _handle___eq__(self, truism): def _handle___ne__(self, truism): lhs, rhs = truism.args - if rhs.cardinality == 1: + if cardinality(rhs) == 1: val = self._helper.eval(rhs, 1)[0] max_int = vsa.StridedInterval.max_int(len(rhs)) diff --git a/claripy/frontend/replacement_frontend.py b/claripy/frontend/replacement_frontend.py index caceae2cd..6dff27c55 100644 --- a/claripy/frontend/replacement_frontend.py +++ b/claripy/frontend/replacement_frontend.py @@ -5,7 +5,7 @@ import weakref from contextlib import suppress -from claripy import backends +from claripy import backends, cardinality from claripy.ast.base import Base from claripy.ast.bool import BoolV, false from claripy.ast.bv import BVV @@ -282,11 +282,11 @@ def _add(self, constraints, invalidate_cache=True): if not satisfiable: self.add_replacement(rc, false()) for old, new in replacements: - if old.cardinality == 1: + if cardinality(old) == 1: continue rold = self._replacement(old) - if rold.cardinality == 1: + if cardinality(rold) == 1: continue self.add_replacement(old, rold.intersection(new)) diff --git a/tests/test_expression.py b/tests/test_expression.py index 72cc1c8ff..8d1f31191 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -5,6 +5,7 @@ import claripy import claripy.backends +from claripy import cardinality, multivalued, singlevalued class TestExpression(unittest.TestCase): @@ -184,30 +185,30 @@ def test_cardinality(self): n = claripy.BVV(10, 32) m = claripy.BVV(20, 32) - self.assertEqual(y.cardinality, 21) - self.assertEqual(x.cardinality, 2**32) - self.assertEqual(n.cardinality, 1) - self.assertEqual(m.cardinality, 1) - self.assertEqual(n.union(m).cardinality, 2) - self.assertEqual(n.union(y).cardinality, 111) - self.assertEqual(y.intersection(x).cardinality, 21) - self.assertEqual(n.intersection(m).cardinality, 0) - self.assertEqual(y.intersection(m).cardinality, 0) + self.assertEqual(cardinality(y), 21) + self.assertEqual(cardinality(x), 2**32) + self.assertEqual(cardinality(n), 1) + self.assertEqual(cardinality(m), 1) + self.assertEqual(cardinality(n.union(m)), 2) + self.assertEqual(cardinality(n.union(y)), 111) + self.assertEqual(cardinality(y.intersection(x)), 21) + self.assertEqual(cardinality(n.intersection(m)), 0) + self.assertEqual(cardinality(y.intersection(m)), 0) - self.assertTrue(n.singlevalued) - self.assertFalse(n.multivalued) + self.assertTrue(singlevalued(n)) + self.assertFalse(multivalued(n)) - self.assertTrue(y.multivalued) - self.assertFalse(y.singlevalued) + self.assertTrue(multivalued(y)) + self.assertFalse(singlevalued(y)) - self.assertFalse(x.singlevalued) - self.assertTrue(x.multivalued) + self.assertFalse(singlevalued(x)) + self.assertTrue(multivalued(x)) - self.assertFalse(y.union(m).singlevalued) - self.assertTrue(y.union(m).multivalued) + self.assertFalse(singlevalued(y.union(m))) + self.assertTrue(multivalued(y.union(m))) - self.assertFalse(y.intersection(m).singlevalued) - self.assertFalse(y.intersection(m).multivalued) + self.assertFalse(singlevalued(y.intersection(m))) + self.assertFalse(multivalued(y.intersection(m))) def test_if_stuff(self): x = claripy.BVS("x", 32) diff --git a/tests/test_vsa.py b/tests/test_vsa.py index 46dc1fda6..4e2d283f4 100644 --- a/tests/test_vsa.py +++ b/tests/test_vsa.py @@ -83,7 +83,7 @@ def test_reversed_concat(self): def test_simple_cardinality(self): x = claripy.BVS("x", 32).annotate(claripy.annotation.StridedIntervalAnnotation(0xA, 0xA, 0x14)) - assert x.cardinality == 2 + assert claripy.cardinality(x) == 2 def test_south_pole_splitting(self): # south-pole splitting