Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move VSA helper methods to algorithms sub-package #535

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions claripy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from __future__ import annotations

from claripy import algorithm, ast
from claripy.algorithm import is_false, is_true, simplify
from claripy.algorithm import (
cardinality,
constraint_to_si,
is_false,
is_true,
multivalued,
simplify,
singlevalued,
stride,
)
from claripy.annotation import Annotation, RegionAnnotation, SimplificationAvoidanceAnnotation
from claripy.ast.bool import (
And,
Expand All @@ -10,7 +19,6 @@
If,
Not,
Or,
constraint_to_si,
false,
ite_cases,
ite_dict,
Expand Down Expand Up @@ -210,4 +218,8 @@
"SolverReplacement",
"SolverStrings",
"SolverVSA",
"cardinality",
"multivalued",
"singlevalued",
"stride",
)
6 changes: 6 additions & 0 deletions claripy/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from .bool_check import is_false, is_true
from .simplify import simplify
from .vsa_helper import cardinality, constraint_to_si, multivalued, singlevalued, stride

__all__ = (
"is_false",
"is_true",
"simplify",
"singlevalued",
"multivalued",
"cardinality",
"constraint_to_si",
"stride",
)
10 changes: 8 additions & 2 deletions claripy/algorithm/simplify.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import logging
from contextlib import suppress
from itertools import chain
from typing import TypeVar

import claripy
from claripy.ast.base import Base, SimplificationLevel
from claripy.errors import BackendError

log = logging.getLogger(__name__)

Expand All @@ -19,8 +22,11 @@ def simplify(expr: T) -> T:
if expr.is_leaf() or expr._simplified == SimplificationLevel.FULL_SIMPLIFY:
return expr

simplified = expr._first_backend("simplify")
if simplified is None:
for backend in claripy.backends.all_backends:
with suppress(BackendError):
simplified = backend.simplify(expr)
break
else:
log.debug("Unable to simplify expression")
return expr

Expand Down
39 changes: 39 additions & 0 deletions claripy/algorithm/vsa_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import claripy

if TYPE_CHECKING:
from claripy.ast import BV


def singlevalued(expr: BV) -> bool:
"""Determines if the expression is single-valued."""
return claripy.backends.vsa.singlevalued(expr)


def multivalued(expr: BV) -> bool:
"""Determines if the expression is multi-valued."""
return claripy.backends.vsa.multivalued(expr)


def cardinality(expr: BV) -> int:
"""Returns the number of possible values the expression can take."""
return claripy.backends.vsa.cardinality(expr)


def stride(expr: BV) -> int:
"""Returns the stride of the expression."""
return claripy.backends.vsa.stride(expr)


def constraint_to_si(expr: BV) -> tuple[bool, list[tuple[BV, BV]]]:
"""
Convert a constraint to SI if possible.

:param expr: The expression to convert.
:return: A tuple of satisfiable and a list of replacements.
"""

return claripy.backends.vsa.constraint_to_si(expr)
23 changes: 0 additions & 23 deletions claripy/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,17 +1207,6 @@ def ite_excavated(self: T) -> T:
# these are convenience operations
#

def _first_backend(self, what):
for b in claripy.backends.all_backends:
if b in self._errored:
continue

try:
return getattr(b, what)(self)
except BackendError:
pass
return None

@property
def concrete_value(self):
try:
Expand All @@ -1228,18 +1217,6 @@ def concrete_value(self):
except BackendError:
return self

@property
def singlevalued(self) -> bool:
return self._first_backend("singlevalued")

@property
def multivalued(self) -> bool:
return self._first_backend("multivalued")

@property
def cardinality(self) -> int:
return self._first_backend("cardinality")

@property
def concrete(self) -> bool:
return not self.symbolic
Expand Down
26 changes: 0 additions & 26 deletions claripy/ast/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import lru_cache
from typing import TYPE_CHECKING, overload

import claripy
from claripy import operations
from claripy.algorithm.bool_check import is_false, is_true
from claripy.ast.base import ASTCacheKey, Base, _make_name
Expand Down Expand Up @@ -227,28 +226,3 @@ def reverse_ite_cases(ast):
queue.append((And(condition, Not(ast.args[0])), ast.args[2]))
else:
yield condition, ast


def constraint_to_si(expr):
"""
Convert a constraint to SI if possible.

:param expr:
:return:
"""

satisfiable = True
replace_list = []

satisfiable, replace_list = claripy.backends.vsa.constraint_to_si(expr)

# Make sure the replace_list are all ast.bvs
for i in range(len(replace_list)): # pylint:disable=consider-using-enumerate
ori, new = replace_list[i]
if not isinstance(new, Base):
new = claripy.BVS(
new.name, new._bits, min=new._lower_bound, max=new._upper_bound, stride=new._stride, explicit_name=True
)
replace_list[i] = (ori, new)

return satisfiable, replace_list
3 changes: 3 additions & 0 deletions claripy/backends/backend_vsa/backend_vsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,5 +470,8 @@ def CreateTopStridedInterval(bits, name=None, uninitialized=False):
def constraint_to_si(self, expr):
return Balancer(self, expr).compat_ret

def stride(self, expr):
return self.convert(expr).stride


BackendVSA.CreateStridedInterval = staticmethod(CreateStridedInterval)
15 changes: 8 additions & 7 deletions claripy/balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import claripy
import claripy.backends.backend_vsa as vsa
from claripy.algorithm.vsa_helper import cardinality

from .ast.base import Base
from .ast.bool import Bool
Expand Down Expand Up @@ -105,7 +106,7 @@ def _same_bound_bv(self, a):

@staticmethod
def _cardinality(a):
return a.cardinality if isinstance(a, Base) else 0
return cardinality(a) if isinstance(a, Base) else 0

@staticmethod
def _min(a, signed=False):
Expand Down Expand Up @@ -176,7 +177,7 @@ def _align_ast(self, a):
try:
if isinstance(a, BV):
return self._align_bv(a)
if isinstance(a, Bool) and len(a.args) == 2 and a.args[1].cardinality > a.args[0].cardinality:
if isinstance(a, Bool) and len(a.args) == 2 and cardinality(a.args[1]) > cardinality(a.args[0]):
return self._reverse_comparison(a)
return a
except ClaripyBalancerError:
Expand Down Expand Up @@ -275,7 +276,7 @@ def _handleable_truism(t):
if len(t.args) < 2:
l.debug("can't do anything with an unop bool")
return None
if t.args[0].cardinality > 1 and t.args[1].cardinality > 1:
if claripy.cardinality(t.args[0]) > 1 and claripy.cardinality(t.args[1]) > 1:
l.debug("can't do anything because we have multiple multivalued guys")
return False
if t.op == "If":
Expand All @@ -289,7 +290,7 @@ def _adjust_truism(t):
Swap the operands of the truism if the unknown variable is on the right side and the concrete value is on the
left side.
"""
if t.args[0].cardinality == 1 and t.args[1].cardinality > 1:
if claripy.cardinality(t.args[0]) == 1 and claripy.cardinality(t.args[1]) > 1:
return Balancer._reverse_comparison(t)
return t

Expand Down Expand Up @@ -383,7 +384,7 @@ def _balance(self, truism):

try:
inner_aligned = self._align_truism(truism)
if inner_aligned.args[1].cardinality > 1:
if cardinality(inner_aligned.args[1]) > 1:
l.debug("can't do anything because we have multiple multivalued guys")
return truism

Expand Down Expand Up @@ -653,7 +654,7 @@ def _handle_comparison(self, truism):

def _handle___eq__(self, truism):
lhs, rhs = truism.args
if rhs.cardinality != 1:
if cardinality(rhs) != 1:
common = self._same_bound_bv(lhs.intersection(rhs))
mn, mx = self._range(common)
self._add_upper_bound(lhs, mx)
Expand All @@ -667,7 +668,7 @@ def _handle___eq__(self, truism):

def _handle___ne__(self, truism):
lhs, rhs = truism.args
if rhs.cardinality == 1:
if cardinality(rhs) == 1:
val = self._helper.eval(rhs, 1)[0]
max_int = vsa.StridedInterval.max_int(len(rhs))

Expand Down
6 changes: 3 additions & 3 deletions claripy/frontend/replacement_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import weakref
from contextlib import suppress

from claripy import backends
from claripy import backends, cardinality
from claripy.ast.base import Base
from claripy.ast.bool import BoolV, false
from claripy.ast.bv import BVV
Expand Down Expand Up @@ -282,11 +282,11 @@ def _add(self, constraints, invalidate_cache=True):
if not satisfiable:
self.add_replacement(rc, false())
for old, new in replacements:
if old.cardinality == 1:
if cardinality(old) == 1:
continue

rold = self._replacement(old)
if rold.cardinality == 1:
if cardinality(rold) == 1:
continue

self.add_replacement(old, rold.intersection(new))
Expand Down
39 changes: 20 additions & 19 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import claripy
import claripy.backends
from claripy import cardinality, multivalued, singlevalued


class TestExpression(unittest.TestCase):
Expand Down Expand Up @@ -184,30 +185,30 @@ def test_cardinality(self):
n = claripy.BVV(10, 32)
m = claripy.BVV(20, 32)

self.assertEqual(y.cardinality, 21)
self.assertEqual(x.cardinality, 2**32)
self.assertEqual(n.cardinality, 1)
self.assertEqual(m.cardinality, 1)
self.assertEqual(n.union(m).cardinality, 2)
self.assertEqual(n.union(y).cardinality, 111)
self.assertEqual(y.intersection(x).cardinality, 21)
self.assertEqual(n.intersection(m).cardinality, 0)
self.assertEqual(y.intersection(m).cardinality, 0)
self.assertEqual(cardinality(y), 21)
self.assertEqual(cardinality(x), 2**32)
self.assertEqual(cardinality(n), 1)
self.assertEqual(cardinality(m), 1)
self.assertEqual(cardinality(n.union(m)), 2)
self.assertEqual(cardinality(n.union(y)), 111)
self.assertEqual(cardinality(y.intersection(x)), 21)
self.assertEqual(cardinality(n.intersection(m)), 0)
self.assertEqual(cardinality(y.intersection(m)), 0)

self.assertTrue(n.singlevalued)
self.assertFalse(n.multivalued)
self.assertTrue(singlevalued(n))
self.assertFalse(multivalued(n))

self.assertTrue(y.multivalued)
self.assertFalse(y.singlevalued)
self.assertTrue(multivalued(y))
self.assertFalse(singlevalued(y))

self.assertFalse(x.singlevalued)
self.assertTrue(x.multivalued)
self.assertFalse(singlevalued(x))
self.assertTrue(multivalued(x))

self.assertFalse(y.union(m).singlevalued)
self.assertTrue(y.union(m).multivalued)
self.assertFalse(singlevalued(y.union(m)))
self.assertTrue(multivalued(y.union(m)))

self.assertFalse(y.intersection(m).singlevalued)
self.assertFalse(y.intersection(m).multivalued)
self.assertFalse(singlevalued(y.intersection(m)))
self.assertFalse(multivalued(y.intersection(m)))

def test_if_stuff(self):
x = claripy.BVS("x", 32)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_reversed_concat(self):

def test_simple_cardinality(self):
x = claripy.BVS("x", 32).annotate(claripy.annotation.StridedIntervalAnnotation(0xA, 0xA, 0x14))
assert x.cardinality == 2
assert claripy.cardinality(x) == 2

def test_south_pole_splitting(self):
# south-pole splitting
Expand Down
Loading