diff --git a/simplesat/dependency_solver.py b/simplesat/dependency_solver.py index 61ff95d..aaafe03 100644 --- a/simplesat/dependency_solver.py +++ b/simplesat/dependency_solver.py @@ -52,12 +52,7 @@ def _create_rules(self, request): pool.package_id(package) for package in pool.what_provides(requirement) ] - self._policy.add_packages_by_id(requirement_ids) - - # Add installed packages. - self._policy.add_packages_by_id( - [pool.package_id(package) for package in installed_repository] - ) + self._policy.add_requirements(requirement_ids) installed_map = collections.OrderedDict() for package in installed_repository: diff --git a/simplesat/sat/assignment_set.py b/simplesat/sat/assignment_set.py new file mode 100644 index 0000000..c165f8d --- /dev/null +++ b/simplesat/sat/assignment_set.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +import six + +from collections import OrderedDict + + +class _MISSING(object): + pass +MISSING = _MISSING() + + +class AssignmentSet(object): + + """A collection of literals and their assignments.""" + + def __init__(self, assignments=None): + # Changelog is a dict of id -> (original value, new value) + # FIXME: Verify that we really need ordering here + self._data = OrderedDict() + self._orig = {} + self._cached_changelog = None + self._assigned_literals = set() + for k, v in (assignments or {}).items(): + self[k] = v + + def __setitem__(self, key, value): + assert key > 0 + + prev_value = self.get(key) + + if prev_value is not None: + self._assigned_literals.difference_update((key, -key)) + + if value is not None: + self._assigned_literals.add(key if value else -key) + + self._update_diff(key, value) + self._data[key] = value + + def __delitem__(self, key): + self._update_diff(key, MISSING) + prev = self._data.pop(key) + if prev is not None: + self._assigned_literals.difference_update((key, -key)) + + def __getitem__(self, key): + return self._data[key] + + def get(self, key, default=None): + return self._data.get(key, default) + + def __len__(self): + return len(self._data) + + def __contains__(self, key): + return key in self._data + + def items(self): + return list(self._data.items()) + + def iteritems(self): + return six.iteritems(self._data) + + def keys(self): + return list(self._data.keys()) + + def values(self): + return list(self._data.values()) + + def _update_diff(self, key, value): + prev = self._data.get(key, MISSING) + self._orig.setdefault(key, prev) + # If a value changes, dump the cached changelog + self._cached_changelog = None + + def get_changelog(self): + if self._cached_changelog is None: + self._cached_changelog = { + key: (old, new) + for key, old in six.iteritems(self._orig) + for new in [self._data.get(key, MISSING)] + if new != old + } + return self._cached_changelog + + def consume_changelog(self): + old = self.get_changelog() + self._orig = {} + self._cached_changelog = {} + return old + + def copy(self): + new = AssignmentSet() + new._data = self._data.copy() + new._orig = self._orig.copy() + new._assigned_literals = self._assigned_literals.copy() + return new + + def value(self, lit): + """ Return the value of literal in terms of the positive. """ + if lit in self._assigned_literals: + return True + elif -lit in self._assigned_literals: + return False + else: + return None + + @property + def num_assigned(self): + return len(self._assigned_literals) diff --git a/simplesat/sat/clause.py b/simplesat/sat/clause.py index 8a54178..a1014a9 100644 --- a/simplesat/sat/clause.py +++ b/simplesat/sat/clause.py @@ -2,8 +2,6 @@ from collections import OrderedDict -from .utils import value - class Constraint(object): pass @@ -13,6 +11,7 @@ class Clause(Constraint): def __init__(self, lits, learned=False): self.learned = learned + # This maintains the ordering while removing duplicate values self.lits = list(OrderedDict.fromkeys(lits).keys()) def rewatch(self, assignments, lit): @@ -41,7 +40,7 @@ def rewatch(self, assignments, lit): if lits[0] == -lit: lits[0], lits[1] = lits[1], -lit - if value(lits[0], assignments) is True: + if assignments.value(lits[0]) is True: # This clause has been satisfied, add it back to the watch list. No # unit information can be deduced. return None @@ -49,7 +48,7 @@ def rewatch(self, assignments, lit): # Look for another literal to watch, and switch it with lit[1] to keep # the assumption on the watched literals in place. for n, other in enumerate(lits[2:]): - if value(other, assignments) is not False: + if assignments.value(other) is not False: # Found a new literal that could serve as a watch. lits[1], lits[n + 2] = other, -lit return None diff --git a/simplesat/sat/minisat.py b/simplesat/sat/minisat.py index 0d80507..99f57f6 100644 --- a/simplesat/sat/minisat.py +++ b/simplesat/sat/minisat.py @@ -8,9 +8,9 @@ from six.moves import range +from .assignment_set import AssignmentSet from .clause import Clause from .policy import DefaultPolicy -from .utils import value class MiniSATSolver(object): @@ -41,7 +41,7 @@ def __init__(self, policy=None): self.clauses = [] self.watches = defaultdict(list) - self.assignments = OrderedDict() + self.assignments = AssignmentSet() # A list of literals which become successively true (because of direct # assignment, or by unit propagation). @@ -64,7 +64,6 @@ def __init__(self, policy=None): # Whether the system is satisfiable. self.status = None - # self._policy = policy or DefaultPolicy() def add_clause(self, clause): @@ -116,7 +115,7 @@ def propagate(self): if unit is not None: # TODO Refactor this to take into account the return value # of enqueue(). - if value(unit, self.assignments) is False: + if self.assignments.value(unit) is False: # Conflict. Clear the queue and re-insert the remaining # unwatched clauses into the watch list. self.prop_queue.clear() @@ -130,7 +129,7 @@ def propagate(self): def enqueue(self, lit, cause=None): """ Enqueue a new true literal. """ - status = value(lit, self.assignments) + status = self.assignments.value(lit) if status is not None: # Known fact. Don't enqueue, but return whether this fact # contradicts the earlier assignment. @@ -176,7 +175,7 @@ def search(self): def validate(self, solution_map): """Check whether a given set of assignments solves this SAT problem. """ - solution_literals = {(+1 if status else -1) * variable + solution_literals = {variable if status else -variable for variable, status in solution_map.items()} for clause in self.clauses: if len(set(clause.lits) & solution_literals) == 0: @@ -285,8 +284,7 @@ def assume(self, lit): def number_assigned(self): """ Return the number of currently assigned variables. """ - return len([value for value in self.assignments.values() - if value is not None]) + return self.assignments.num_assigned @property def number_variables(self): diff --git a/simplesat/sat/policy.py b/simplesat/sat/policy.py index 1dfebff..f936605 100644 --- a/simplesat/sat/policy.py +++ b/simplesat/sat/policy.py @@ -1,39 +1,179 @@ import abc +from collections import Counter +from functools import partial + import six from enstaller.collections import DefaultOrderedDict +from .priority_queue import PriorityQueue + class IPolicy(six.with_metaclass(abc.ABCMeta)): + + @abc.abstractmethod + def add_requirements(self, package_ids): + """ Submit packages to the policy for consideration. + """ + @abc.abstractmethod def get_next_package_id(self, assignments, clauses): """ Returns a undecided variable (i.e. integer > 0) for the given sets - of assignements and clauses. + of assignments and clauses. + + Parameters + ---------- + assignments : OrderedDict + The current assignments of each literal. Keys are variables + (integer > 0) and values are one of (True, False, None). + clauses : List of Clause + The collection of Clause objects to satisfy. """ class DefaultPolicy(IPolicy): + + def __init__(self, *args): + pass + + def add_requirements(self, assignments): + pass + def get_next_package_id(self, assignments, _): # Given a dictionary of partial assignments, get an undecided variable # to be decided next. - undecided = [ + undecided = ( package_id for package_id, status in six.iteritems(assignments) if status is None - ] - return undecided[0] + ) + return next(undecided) + + +class PolicyLogger(IPolicy): + def __init__(self, PolicyFactory, pool, installed_repository): + self._policy = PolicyFactory(pool, installed_repository) + self._log_installed = list(installed_repository.iter_packages()) + self._log_suggestions = [] + self._log_required = [] + self._log_pool = pool + self._log_assignment_changes = [] + + def get_next_package_id(self, assignments, clauses): + self._log_assignment_changes.append(assignments.get_changelog()) + pkg_id = self._policy.get_next_package_id(assignments, clauses) + self._log_suggestions.append(pkg_id) + return pkg_id + + def add_requirements(self, package_ids): + self._log_required.extend(package_ids) + self._policy.add_requirements(package_ids) + + def _log_histogram(self, pkg_ids): + c = Counter(pkg_ids) + lines = ( + "{:>25} {}".format(self._log_pretty_pkg_id(k), v) + for k, v in sorted(c.items(), key=lambda p: p[1]) + ) + pretty = '\n'.join(lines) + return c, pretty + + def _log_pretty_pkg_id(self, pkg_id): + p = self._log_pool._id_to_package[pkg_id] + return '{} {}'.format(p.name, p.version) + + +class _InstalledFirstPolicy(IPolicy): + + CURRENT = int(-2e7) + REQUIRED = int(-1e7) + + def __init__(self, pool, installed_repository): + self._pool = pool + self._installed_pkg_ids = set( + pool.package_id(package) for package in + installed_repository.iter_packages() + ) + self._seen_pkg_ids = set() + self._required_pkg_ids = set() + self._pkg_id_priority = {} + self._unassigned_pkg_ids = PriorityQueue() + + def _update_pkg_id_priority(self, assignments=None): + pkg_id_priority = {} + if assignments is not None: + self._seen_pkg_ids.update(assignments.keys()) + + ordered_pkg_ids = sorted( + self._seen_pkg_ids, + key=lambda p: self._pool._id_to_package[p].version, + reverse=True, + ) + + for priority, pkg_id in enumerate(ordered_pkg_ids): + + # Determine the new priority of this pkg + if pkg_id in self._required_pkg_ids: + priority += self.REQUIRED + elif pkg_id in self._installed_pkg_ids: + priority += self.CURRENT + + if pkg_id in self._unassigned_pkg_ids: + self._unassigned_pkg_ids.push(pkg_id, priority=priority) + + pkg_id_priority[pkg_id] = priority + self._pkg_id_priority = pkg_id_priority + + def add_requirements(self, package_ids): + self._required_pkg_ids.update(package_ids) + self._seen_pkg_ids.update(package_ids) + self._update_pkg_id_priority() + + def get_next_package_id(self, assignments, clauses): + self._update_cache_from_assignments(assignments) + + # Grab the most interesting looking currently unassigned id + return self._unassigned_pkg_ids.peek() + + def _update_cache_from_assignments(self, assignments): + changelog = assignments.consume_changelog() + + if not self._seen_pkg_ids.issuperset(changelog): + self._update_pkg_id_priority(assignments) + + for key, (old, new) in six.iteritems(changelog): + if new is None: + # Newly unassigned + priority = self._pkg_id_priority[key] + self._unassigned_pkg_ids.push(key, priority=priority) + elif old is None: + # No longer unassigned (because new is not None) + self._unassigned_pkg_ids.remove(key) + # The remaining case is either True -> False or False -> True, which is + # probably not possible and we don't care about anyway, or + # MISSING -> (True|False) + + # A very cheap sanity check + ours = len(self._unassigned_pkg_ids) + theirs = len(assignments) - assignments.num_assigned + assert ours == theirs, "We failed to track variable assignments" + + +InstalledFirstPolicy = partial(PolicyLogger, _InstalledFirstPolicy) + + +class OldInstalledFirstPolicy(IPolicy): -class InstalledFirstPolicy(IPolicy): def __init__(self, pool, installed_repository): self._pool = pool - self._decision_set = set() self._installed_map = set( pool.package_id(package) for package in installed_repository.iter_packages() ) + self._decision_set = self._installed_map.copy() - def add_packages_by_id(self, package_ids): + def add_requirements(self, package_ids): # TODO Just make this add_requirement. for package_id in package_ids: self._decision_set.add(package_id) @@ -82,12 +222,16 @@ def _group_packages_by_name(self, decision_set): def _handle_empty_decision_set(self, assignments, clauses): # TODO inefficient and verbose + + # The assignments with value None unassigned_ids = set( literal for literal, status in six.iteritems(assignments) if status is None ) + # The rest (true and false) assigned_ids = set(assignments.keys()) - unassigned_ids + # Magnitude is literal, sign is Truthiness signed_assignments = set() for variable in assigned_ids: if assignments[variable]: @@ -97,13 +241,14 @@ def _handle_empty_decision_set(self, assignments, clauses): for clause in clauses: # TODO Need clause.undecided_literals property - if signed_assignments.intersection(clause.lits): + if not signed_assignments.isdisjoint(clause.lits): # Clause is true continue literals = clause.lits undecided = unassigned_ids.intersection(literals) + # The set of relevant literals still undecided self._decision_set.update(abs(lit) for lit in undecided) if len(self._decision_set) == 0: diff --git a/simplesat/sat/priority_queue.py b/simplesat/sat/priority_queue.py new file mode 100644 index 0000000..0ecb86d --- /dev/null +++ b/simplesat/sat/priority_queue.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from functools import partial +from heapq import heappush, heappop +from itertools import count + + +class _REMOVED_TASK(object): + pass +REMOVED_TASK = _REMOVED_TASK() + + +class PriorityQueue(object): + """ A priority queue implementation that supports reprioritizing or + removing tasks, given that tasks are unique. + + Borrowed from: https://docs.python.org/3/library/heapq.html + """ + + def __init__(self): + # list of entries arranged in a heap + self._pq = [] + # mapping of tasks to entries + self._entry_finder = {} + # unique id genrator for tie-breaking + self._next_id = partial(next, count()) + + def __len__(self): + return len(self._entry_finder) + + def __bool__(self): + return bool(len(self)) + + def __contains__(self, key): + return key in self._entry_finder + + def push(self, task, priority=0): + "Add a new task or update the priority of an existing task" + return self._push(priority, self._next_id(), task) + + def peek(self): + """ Return the task with the lowest priority. + + This will pop and repush if a REMOVED task is found. + """ + if not self._pq: + raise KeyError('peek from an empty priority queue') + entry = self._pq[0] + if entry[-1] is REMOVED_TASK: + entry = self._pop() + self._push(*entry) + return entry[-1] + + def pop(self): + 'Remove and return the lowest priority task. Raise KeyError if empty.' + _, _, task = self._pop() + return task + + def pop_many(self, n=None): + """ Return a list of length n of popped elements. If n is not + specified, pop the entire queue. """ + if n is None: + n = len(self) + result = [] + for _ in range(n): + result.append(self.pop()) + return result + + def discard(self, task): + "Remove an existing task if present. If not, do nothing." + try: + self.remove(task) + except KeyError: + pass + + def remove(self, task): + "Remove an existing task. Raise KeyError if not found." + entry = self._entry_finder.pop(task) + entry[-1] = REMOVED_TASK + + def _pop(self): + while self._pq: + entry = heappop(self._pq) + if entry[-1] is not REMOVED_TASK: + del self._entry_finder[entry[-1]] + return entry + raise KeyError('pop from an empty priority queue') + + def _push(self, priority, task_id, task): + if task in self: + o_priority, _, o_task = self._entry_finder[task] + # Still check the task, which might now be REMOVED + if priority == o_priority and task == o_task: + # We're pushing something we already have, do nothing + return + else: + # Make space for the new entry + self.remove(task) + entry = [priority, task_id, task] + self._entry_finder[task] = entry + heappush(self._pq, entry) diff --git a/simplesat/sat/tests/test_assignment_set.py b/simplesat/sat/tests/test_assignment_set.py new file mode 100644 index 0000000..59128c8 --- /dev/null +++ b/simplesat/sat/tests/test_assignment_set.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import unittest + +from ..assignment_set import AssignmentSet, MISSING + + +class TestAssignmentSet(unittest.TestCase): + + def test_starts_empty(self): + AS = AssignmentSet() + self.assertEqual(AS.num_assigned, 0) + self.assertEqual(len(AS), 0) + self.assertEqual({}, AS.get_changelog()) + self.assertEqual([], AS.keys()) + self.assertEqual([], AS.values()) + self.assertEqual([], AS.items()) + + def test_num_assigned(self): + AS = AssignmentSet() + + AS[1] = None + self.assertEqual(AS.num_assigned, 0) + + AS[2] = True + self.assertEqual(AS.num_assigned, 1) + + AS[1] = False + self.assertEqual(AS.num_assigned, 2) + + AS[2] = None + self.assertEqual(AS.num_assigned, 1) + + AS[2] = True + self.assertEqual(AS.num_assigned, 2) + + del AS[1] + self.assertEqual(AS.num_assigned, 1) + + AS[2] = False + self.assertEqual(AS.num_assigned, 1) + + AS[2] = None + self.assertEqual(AS.num_assigned, 0) + + del AS[2] + self.assertEqual(AS.num_assigned, 0) + + def test_container(self): + AS = AssignmentSet() + + AS[1] = True + AS[2] = False + + self.assertIn(1, AS) + self.assertNotIn(3, AS) + + AS[4] = None + AS[3] = True + AS[5] = None + + self.assertIn(3, AS) + self.assertIn(5, AS) + + del AS[5] + self.assertNotIn(5, AS) + + expected = [(1, True), (2, False), (4, None), (3, True)] + + manual_result = list(zip(AS.keys(), AS.values())) + self.assertEqual(AS.items(), expected) + self.assertEqual(manual_result, expected) + self.assertEqual(len(AS), len(expected)) + + def test_copy(self): + + AS = AssignmentSet() + + AS[1] = None + AS[2] = True + AS[3] = None + AS[4] = False + AS[5] = True + + expected = { + 1: None, + 2: True, + 3: None, + 4: False, + 5: True, + } + + copied = AS.copy() + + self.assertIsNot(copied._data, AS._data) + self.assertEqual(copied._data, expected) + + expected = {k: MISSING for k in expected} + + self.assertIsNot(copied._orig, AS._orig) + self.assertEqual(copied._orig, expected) + + self.assertEqual(copied.num_assigned, AS.num_assigned) + + del AS[2] + self.assertIn(2, copied) + + def test_value(self): + AS = AssignmentSet() + + AS[1] = False + AS[2] = True + AS[3] = None + + self.assertTrue(AS.value(-1)) + self.assertTrue(AS.value(2)) + + self.assertFalse(AS.value(1)) + self.assertFalse(AS.value(-2)) + + self.assertIs(AS.value(3), None) + self.assertIs(AS.value(-3), None) + + def test_changelog(self): + + AS = AssignmentSet() + + AS[1] = None + + expected = {1: (MISSING, None)} + self.assertEqual(AS.get_changelog(), expected) + + AS[2] = True + expected[2] = (MISSING, True) + self.assertEqual(AS.get_changelog(), expected) + + AS[2] = False + expected[2] = (MISSING, False) + self.assertEqual(AS.get_changelog(), expected) + + del AS[2] + del expected[2] + self.assertEqual(AS.get_changelog(), expected) + + # Keep should preserve the log + self.assertEqual(AS.get_changelog(), expected) + self.assertEqual(AS.get_changelog(), expected) + + # Otherwise clear the log + log = AS.consume_changelog() + self.assertEqual(log, expected) + self.assertEqual(AS.get_changelog(), {}) + + # Test when something is already present + AS[1] = False + expected = {1: (None, False)} + self.assertEqual(AS.get_changelog(), expected) + + del AS[1] + expected = {1: (None, MISSING)} + self.assertEqual(AS.get_changelog(), expected) diff --git a/simplesat/sat/tests/test_minisat.py b/simplesat/sat/tests/test_minisat.py index 64e1a79..8303548 100644 --- a/simplesat/sat/tests/test_minisat.py +++ b/simplesat/sat/tests/test_minisat.py @@ -1,11 +1,9 @@ import unittest -from collections import OrderedDict - import mock import six -from ..utils import value +from ..assignment_set import AssignmentSet from ..clause import Clause from ..minisat import MiniSATSolver @@ -33,7 +31,7 @@ def zm01_solver(add_conflict=False): if add_conflict: s.add_clause(Clause([-18, -3, -19])) - s.assignments = {k: None for k in range(20)} + s.assignments = AssignmentSet({k: None for k in range(1, 20)}) # Load up the assignments from lower decision levels. The call to # assume() will enter a new decision level and enqueue. @@ -51,7 +49,7 @@ class TestClause(unittest.TestCase): def test_rewatch(self): # Given c = Clause([1, -2, 5]) - assignments = {1: False, 2: None, 5: None} + assignments = AssignmentSet({1: False, 2: None, 5: None}) # When unit = c.rewatch(assignments, -1) @@ -63,7 +61,7 @@ def test_rewatch(self): def test_rewatch_true(self): # Given c = Clause([1, -2, 5]) - assignments = {1: True, 2: None, 5: None} + assignments = AssignmentSet({1: True, 2: None, 5: None}) # When unit = c.rewatch(assignments, -1) @@ -75,7 +73,7 @@ def test_rewatch_true(self): def test_rewatch_unit(self): # Given c = Clause([1, -2, 5]) - assignments = {1: False, 2: True, 5: False} + assignments = AssignmentSet({1: False, 2: True, 5: False}) # When unit = c.rewatch(assignments, 2) @@ -161,7 +159,9 @@ def test_propagate_one_level(self, mock_enqueue): s.add_clause(cl2) s.add_clause(cl3) - s.assignments = {1: None, 2: None, 4: None, 5: None, 7: None} + s.assignments = AssignmentSet( + {1: None, 2: None, 4: None, 5: None, 7: None} + ) # When s.assignments[2] = False # Force 2 to be false. @@ -190,7 +190,7 @@ def test_propagate_with_unit_info(self, mock_enqueue): s.add_clause(cl1) s.add_clause(cl2) - s.assignments = {1: None, 2: None, 4: None, 5: None} + s.assignments = AssignmentSet({1: None, 2: None, 4: None, 5: None}) # When s.assignments[2] = False # Force 2 to be false. @@ -216,7 +216,7 @@ def test_propagate_conflict(self): s.add_clause(cl1) s.add_clause(cl2) - s.assignments = {1: True, 2: None, 3: None, 4: None} + s.assignments = AssignmentSet({1: True, 2: None, 3: None, 4: None}) # When s.assignments[2] = False # Force 2 to be false. @@ -235,13 +235,13 @@ def test_setup_does_not_overwrite_assignments(self): s = MiniSATSolver() s.add_clause([-1]) s.add_clause([1, 2, 3]) - expected_assignments = OrderedDict([(1, False), (2, None), (3, None)]) + expected_assignments = [(1, False), (2, None), (3, None)] # When s._setup_assignments() # Then - self.assertEqual(s.assignments, expected_assignments) + self.assertEqual(list(s.assignments.items()), expected_assignments) def _assertWatchesNotTrue(self, watches, assignments): for watch, clauses in watches.items(): @@ -252,7 +252,7 @@ def _assertWatchesNotTrue(self, watches, assignments): def test_enqueue(self): # Given s = MiniSATSolver() - s.assignments = {1: True, 2: None} + s.assignments = AssignmentSet({1: True, 2: None}) # When / then status = s.enqueue(1) @@ -270,7 +270,7 @@ def test_propagation_with_queue(self): cl2 = Clause([1, 3, 4]) s.add_clause(cl1) s.add_clause(cl2) - s.assignments = {1: None, 2: None, 3: None, 4: None} + s.assignments = AssignmentSet({1: None, 2: None, 3: None, 4: None}) # When s.enqueue(-2) @@ -278,7 +278,8 @@ def test_propagation_with_queue(self): # Then self.assertIsNone(conflict) - self.assertEqual(s.assignments, {1: True, 2: False, 3: None, 4: None}) + self.assertEqual(s.assignments._data, + {1: True, 2: False, 3: None, 4: None}) self.assertEqual(s.trail, [-2, 1]) six.assertCountEqual(self, s.watches[-1], [cl1, cl2]) six.assertCountEqual(self, s.watches[-2], [cl1]) @@ -293,7 +294,7 @@ def test_propagation_with_queue_multiple_implications(self): s.add_clause(cl1) s.add_clause(cl2) s.add_clause(cl3) - s.assignments = {1: None, 2: None, 3: None, 4: None} + s.assignments = AssignmentSet({1: None, 2: None, 3: None, 4: None}) # When s.enqueue(-1) @@ -301,7 +302,7 @@ def test_propagation_with_queue_multiple_implications(self): # Then self.assertIsNone(conflict) - self.assertEqual(s.assignments, + self.assertEqual(s.assignments._data, {1: False, 2: False, 3: False, 4: False}) self.assertEqual(s.trail, [-1, -2, -3, -4]) @@ -318,7 +319,7 @@ def test_propagation_with_queue_conflicted(self): s.add_clause(cl1) s.add_clause(cl2) s.add_clause(cl3) - s.assignments = {1: None, 2: None, 3: None, 4: True} + s.assignments = AssignmentSet({1: None, 2: None, 3: None, 4: True}) # When s.enqueue(-1) @@ -350,9 +351,9 @@ def test_propagate_zm01(self): last = s.trail_lim[-1] six.assertCountEqual(self, s.trail[last:], - [11, -12, 16, -2, -10, 1, 3, -5, 18]) + [11, -12, 16, -2, -10, 1, 3, -5, 18]) - expected_assignments = { + expected = { 1: True, 2: False, 3: True, @@ -363,20 +364,20 @@ def test_propagate_zm01(self): 16: True, 18: True } - for var, _value in expected_assignments.items(): - self.assertEqual(s.assignments[var], _value) + for lit, val in expected.items(): + self.assertEqual(s.assignments[lit], val) def test_undo_one(self): # Given s = MiniSATSolver() s.trail = [1, 2, -3] - s.assignments = {1: None, 2: None, 3: True} + s.assignments = AssignmentSet({1: None, 2: None, 3: True}) # When s.undo_one() # Then - self.assertEqual(s.assignments, {1: None, 2: None, 3: None}) + self.assertEqual(s.assignments._data, {1: None, 2: None, 3: None}) self.assertEqual(s.trail, [1, 2]) def test_cancel(self): @@ -410,7 +411,7 @@ def test_cancel_until(self): def test_assume_cancel_roundtrip(self): # Given s = MiniSATSolver() - s.assignments = {1: None} + s.assignments = AssignmentSet({1: None}) # When / then s.assume(-1) @@ -452,7 +453,7 @@ def test_cancel_zm01(self): self.assertIsNone(s.reason[var]) for lit, clauses in s.watches.items(): if len(clauses) > 2: - self.assertNotEqual(value(-lit, s.assignments), False) + self.assertNotEqual(s.assignments.value(-lit), False) def test_analyze_same_level(self): # Given diff --git a/simplesat/sat/tests/test_minisat_problems.py b/simplesat/sat/tests/test_minisat_problems.py index 56c9250..305f0ed 100644 --- a/simplesat/sat/tests/test_minisat_problems.py +++ b/simplesat/sat/tests/test_minisat_problems.py @@ -2,7 +2,6 @@ from simplesat.examples.van_der_waerden import van_der_waerden -from ..utils import value from ..clause import Clause from ..minisat import MiniSATSolver @@ -11,8 +10,8 @@ class TestMinisatProblems(unittest.TestCase): """Run the Minisat solver on a range of test problems. """ def test_unit_assignments_are_remembered(self): - # Regression test for #31 (calling _setup_assignments() after enqueueing unit - # clauses caused unit assignments to be erased). + # Regression test for #31 (calling _setup_assignments() after + # enqueueing unit clauses caused unit assignments to be erased). # Given s = MiniSATSolver() @@ -46,7 +45,8 @@ def test_no_assumptions(self): sol = s.search() # Then - self.assertEqual(sol, {1: True, 2: True, 3: True, 4: True}) + result = dict(sol.items()) + self.assertEqual(result, {1: True, 2: True, 3: True, 4: True}) def test_one_assumption(self): # A simple problem with only unit propagation, where one additional @@ -64,13 +64,14 @@ def test_one_assumption(self): sol = s.search() # Then - self.assertEqual(sol, {1: True, 2: True, 3: True, 4: True}) + result = dict(sol.items()) + self.assertEqual(result, {1: True, 2: True, 3: True, 4: True}) def check_solution(clauses, solution): for clause in clauses: for lit in clause: - if value(lit, solution): + if solution.value(lit): # Clause is satisfied. break else: diff --git a/simplesat/sat/tests/test_priority_queue.py b/simplesat/sat/tests/test_priority_queue.py new file mode 100644 index 0000000..93634ab --- /dev/null +++ b/simplesat/sat/tests/test_priority_queue.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import random +import unittest + +from ..priority_queue import PriorityQueue, REMOVED_TASK + + +def assert_raises(exception, f, *args, **kwargs): + try: + f(*args, **kwargs) + except exception: + return + msg = "{}(*{}, **{}) did not raise {}" + raise AssertionError(msg.format(f, args, kwargs, exception)) + + +class TestPriorityQueue(unittest.TestCase): + + def setUp(self): + self.N = 250 + self.priorities = random.sample(range(self.N * 2), self.N) + random.shuffle(self.priorities) + + def test_insert_pop(self): + pq = PriorityQueue() + items = list(enumerate(self.priorities)) + + for p, task in items: + pq.push(task, priority=p) + + result = pq.pop_many() + + self.assertEqual(len(pq), 0) + self.assertFalse(pq) + + expected = [x[1] for x in sorted(items)] + + self.assertEqual(result, expected) + + def test_len_remove_discard(self): + pq = PriorityQueue() + for t, p in enumerate(self.priorities): + self.assertEqual(len(pq), t) + pq.push(t, priority=p) + self.assertEqual(len(pq), self.N) + + for t in range(self.N): + self.assertEqual(len(pq), self.N - t) + if t % 2: + pq.remove(t) + else: + pq.discard(t) + self.assertEqual(len(pq), 0) + + def test_len_pop(self): + pq = PriorityQueue() + for t, p in enumerate(self.priorities): + self.assertEqual(len(pq), t) + pq.push(t, priority=p) + self.assertEqual(len(pq), self.N) + + for t in range(self.N): + self.assertEqual(len(pq), self.N - t) + pq.pop() + self.assertEqual(len(pq), 0) + + def test_bool_full_empty(self): + pq = PriorityQueue() + self.assertFalse(pq) + + pq.push(0) + self.assertTrue(pq) + + pq.push(1) + self.assertTrue(pq) + + pq.discard(2) + self.assertTrue(pq) + + pq.discard(1) + self.assertTrue(pq) + + pq.remove(0) + self.assertFalse(pq) + + def test_tasks_are_unique(self): + pq = PriorityQueue() + + pq.push(0) + self.assertEqual(1, len(pq)) + + pq.push(0) + self.assertEqual(1, len(pq)) + + pq.push(1) + self.assertEqual(2, len(pq)) + + pq.push(0) + self.assertEqual(2, len(pq)) + + pq.remove(0) + self.assertEqual(1, len(pq)) + + def test_peek(self): + pq = PriorityQueue() + + for i in range(100): + pq.push(i, priority=i) + + self.assertEqual(0, pq.peek()) + + pq.discard(0) + self.assertEqual(1, pq.peek()) + for i in range(2, 80): + pq.remove(i) + + self.assertEqual(1, pq.peek()) + self.assertEqual(1, pq.pop()) + + self.assertEqual(80, pq.peek()) + + def test_no_pop_removed(self): + pq = PriorityQueue() + pq.push(0) + pq.remove(0) + assert_raises( + KeyError, + pq.pop, + ) + + def test_same_priority_stable(self): + pq = PriorityQueue() + expected = self.priorities + for t in expected: + pq.push(t) + result = pq.pop_many() + self.assertEqual(expected, result) + + def test_empty_throws_exceptions(self): + pq = PriorityQueue() + assert_raises( + KeyError, + pq.pop, + ) + assert_raises( + KeyError, + pq.peek, + ) + assert_raises( + KeyError, + pq.remove, + 0, + ) + self.assertEqual([], pq.pop_many()) + + def test_contains(self): + pq = PriorityQueue() + + pq.push(0) + self.assertIn(0, pq) + pq.push(0) + self.assertIn(0, pq) + + pq.push(1) + self.assertIn(0, pq) + self.assertIn(1, pq) + + pq.remove(0) + self.assertIn(1, pq) + self.assertNotIn(0, pq) + + pq.pop() + self.assertNotIn(0, pq) + self.assertNotIn(1, pq) + + def test_barrage(self): + pq = PriorityQueue() + + has = set() + + for t, p in enumerate(self.priorities): + pq.push(t, priority=p) + has.add(t) + + for i in range(20000): + + self.assertEqual(len(pq), len(has)) + + t = random.choice(self.priorities) + + if t not in has: + self.assertNotIn(t, pq) + assert_raises( + KeyError, + pq.remove, + t + ) + pq.discard(t) + else: + if random.random() > 0.5: + pq.remove(t) + has.remove(t) + + if random.random() > 0.5: + pq.push(t, priority=random.randrange(self.N)) + has.add(t) + + if random.random() > 0.75 and has: + t = has.pop() + pq.remove(t) + + if random.random() > 0.75 and has: + t = pq.peek() + expected = next( + t for t in sorted(pq._pq) + if t[2] is not REMOVED_TASK + )[2] + self.assertEqual(t, expected) + self.assertEqual(t, pq.pop()) + has.remove(t) diff --git a/simplesat/sat/utils.py b/simplesat/sat/utils.py deleted file mode 100644 index f53cb2e..0000000 --- a/simplesat/sat/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -def value(lit, assignments): - """ Value of a literal given variable assignments. - """ - status = assignments.get(abs(lit)) - if status is None: - return None - is_conjugated = lit < 0 - return is_conjugated is not status diff --git a/simplesat/tests/crash.yaml b/simplesat/tests/crash.yaml index 068aac9..6195fcd 100644 --- a/simplesat/tests/crash.yaml +++ b/simplesat/tests/crash.yaml @@ -30,3 +30,16 @@ packages: request: - operation: "install" requirement: "numpy >= 1.9.2" + +decisions: + 0: libgfortran 3.0.0-2 + 1: MKL 10.3-1 + 2: numpy 1.9.2-1 + +transaction: + - kind: "install" + package: "libgfortran 3.0.0-2" + - kind: "install" + package: "MKL 10.3-1" + - kind: "install" + package: "numpy 1.9.2-1" diff --git a/simplesat/tests/test_scenarios_policy.py b/simplesat/tests/test_scenarios_policy.py index 315857d..b531b48 100644 --- a/simplesat/tests/test_scenarios_policy.py +++ b/simplesat/tests/test_scenarios_policy.py @@ -8,6 +8,32 @@ from .common import Scenario +def _pretty_operations(ops): + ret = [] + for p in ops: + try: + name = p.package.name + version = str(p.package.version) + except AttributeError: + name, version = p.package.split() + ret.append((name, version)) + return ret + + +def _pkg_delta(operations, scenario_operations): + pkg_delta = {} + for p in scenario_operations: + name, version = _pretty_operations([p]) + pkg_delta.setdefault(name, [None, None])[0] = version + for p in operations: + name, version = _pretty_operations([p]) + pkg_delta.setdefault(name, [None, None])[1] = version + for n, v in pkg_delta.items(): + if v[0] == v[1]: + pkg_delta.pop(n) + return pkg_delta + + class ScenarioTestAssistant(object): def _check_solution(self, filename): @@ -33,7 +59,8 @@ def _check_solution(self, filename): scenario.operations) def assertEqualOperations(self, operations, scenario_operations): - for i, (left, right) in enumerate(zip(operations, scenario_operations)): + pairs = zip(operations, scenario_operations) + for i, (left, right) in enumerate(pairs): if not type(left) == type(right): msg = "Item {0!r} differ in kinds: {1!r} vs {2!r}" self.fail(msg.format(i, type(left), type(right))) @@ -48,8 +75,8 @@ def assertEqualOperations(self, operations, scenario_operations): self.fail("Length of operations differ") -class TestNoInstallSet(TestCase, ScenarioTestAssistant): - @expectedFailure +class TestNoInstallSet(ScenarioTestAssistant, TestCase): + def test_crash(self): self._check_solution("crash.yaml") @@ -59,6 +86,7 @@ def test_simple_numpy(self): def test_ipython(self): self._check_solution("ipython.yaml") + @expectedFailure def test_iris(self): self._check_solution("iris.yaml")