From 66f808702ab62a064cb5d8f5ab922a2c9820ba6d Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Tue, 3 Sep 2024 18:18:42 +0200 Subject: [PATCH] [CP-SAT] internal protocol improvements; bug fixes --- ortools/sat/2d_packing_brute_force.cc | 4 +- ortools/sat/2d_rectangle_presolve.cc | 407 ++++++++++++++++++++++++++ ortools/sat/2d_rectangle_presolve.h | 50 ++++ ortools/sat/BUILD.bazel | 21 +- ortools/sat/clause.cc | 5 +- ortools/sat/clause.h | 8 +- ortools/sat/cp_constraints.cc | 88 ++++-- ortools/sat/cp_constraints.h | 18 +- ortools/sat/cp_model_presolve.cc | 148 ++++++++-- ortools/sat/cp_model_search.cc | 9 +- ortools/sat/cp_model_solver.cc | 6 +- ortools/sat/diffn_util.cc | 91 +++++- ortools/sat/diffn_util.h | 17 +- ortools/sat/integer.cc | 279 ++++++++++++------ ortools/sat/integer.h | 131 ++++++--- ortools/sat/integer_expr.cc | 12 +- ortools/sat/integer_expr.h | 4 +- ortools/sat/integer_search.cc | 67 ++--- ortools/sat/integer_search.h | 1 + ortools/sat/intervals.cc | 27 +- ortools/sat/intervals.h | 9 +- ortools/sat/linear_propagation.cc | 5 +- ortools/sat/linear_propagation.h | 2 +- ortools/sat/lp_utils.cc | 2 +- ortools/sat/pb_constraint.cc | 4 +- ortools/sat/pb_constraint.h | 4 +- ortools/sat/probing.cc | 2 +- ortools/sat/sat_base.h | 25 +- ortools/sat/sat_solver.cc | 6 +- ortools/sat/sat_solver.h | 2 + ortools/sat/scheduling_cuts.cc | 6 +- ortools/sat/symmetry.cc | 5 +- ortools/sat/symmetry.h | 4 +- ortools/sat/util.cc | 21 +- ortools/sat/util.h | 5 +- 35 files changed, 1167 insertions(+), 328 deletions(-) create mode 100644 ortools/sat/2d_rectangle_presolve.cc create mode 100644 ortools/sat/2d_rectangle_presolve.h diff --git a/ortools/sat/2d_packing_brute_force.cc b/ortools/sat/2d_packing_brute_force.cc index 8f258a603e5..d6381a99547 100644 --- a/ortools/sat/2d_packing_brute_force.cc +++ b/ortools/sat/2d_packing_brute_force.cc @@ -681,8 +681,8 @@ BruteForceResult BruteForceOrthogonalPacking( for (const PermutableItem& item : items) { result[item.index] = item.position; } - VLOG_EVERY_N_SEC(3, 3) << "Found a feasible packing by brute force. Dot:\n " - << RenderDot(bounding_box_size, result); + // VLOG_EVERY_N_SEC(3, 3) << "Found a feasible packing by brute force. Dot:\n " + // << RenderDot(bounding_box_size, result); return {.status = BruteForceResult::Status::kFoundSolution, .positions_for_solution = result}; } diff --git a/ortools/sat/2d_rectangle_presolve.cc b/ortools/sat/2d_rectangle_presolve.cc new file mode 100644 index 00000000000..9fefb17ff64 --- /dev/null +++ b/ortools/sat/2d_rectangle_presolve.cc @@ -0,0 +1,407 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/2d_rectangle_presolve.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" + +namespace operations_research { +namespace sat { + +bool PresolveFixed2dRectangles( + absl::Span non_fixed_boxes, + std::vector* fixed_boxes) { + // This implementation compiles a set of areas that cannot be occupied by any + // item, then calls ReduceNumberofBoxes() to use these areas to minimize + // `fixed_boxes`. + bool changed = false; + + IntegerValue original_area = 0; + std::vector fixed_boxes_copy; + if (VLOG_IS_ON(1)) { + for (const Rectangle& r : *fixed_boxes) { + original_area += r.Area(); + } + } + if (VLOG_IS_ON(2)) { + fixed_boxes_copy = *fixed_boxes; + } + + const int original_num_boxes = fixed_boxes->size(); + IntegerValue min_x_size = std::numeric_limits::max(); + IntegerValue min_y_size = std::numeric_limits::max(); + + CHECK(!non_fixed_boxes.empty()); + Rectangle bounding_box = non_fixed_boxes[0].bounding_area; + + for (const RectangleInRange& box : non_fixed_boxes) { + bounding_box.GrowToInclude(box.bounding_area); + min_x_size = std::min(min_x_size, box.x_size); + min_y_size = std::min(min_y_size, box.y_size); + } + + // Fixed items are only useful to constraint where the non-fixed items can be + // placed. This means in particular that any part of a fixed item outside the + // bounding box of the non-fixed items is useless. Clip them. + int new_size = 0; + while (new_size < fixed_boxes->size()) { + Rectangle& rectangle = (*fixed_boxes)[new_size]; + if (rectangle.x_min < bounding_box.x_min) { + rectangle.x_min = bounding_box.x_min; + changed = true; + } + if (rectangle.x_max > bounding_box.x_max) { + rectangle.x_max = bounding_box.x_max; + changed = true; + } + if (rectangle.y_min < bounding_box.y_min) { + rectangle.y_min = bounding_box.y_min; + changed = true; + } + if (rectangle.y_max > bounding_box.y_max) { + rectangle.y_max = bounding_box.y_max; + changed = true; + } + if (rectangle.SizeX() <= 0 || rectangle.SizeY() <= 0) { + // The whole rectangle was outside of the domain, remove it. + std::swap(rectangle, (*fixed_boxes)[fixed_boxes->size() - 1]); + fixed_boxes->resize(fixed_boxes->size() - 1); + continue; + } else { + new_size++; + } + } + + std::vector optional_boxes = *fixed_boxes; + + if (bounding_box.x_min > std::numeric_limits::min() && + bounding_box.y_min > std::numeric_limits::min() && + bounding_box.x_max < std::numeric_limits::max() && + bounding_box.y_max < std::numeric_limits::max()) { + // Add fake rectangles to build a frame around the bounding box. This allows + // to find more areas that must be empty. The frame is as follows: + // +************ + // +...........+ + // +...........+ + // +...........+ + // ************+ + optional_boxes.push_back({.x_min = bounding_box.x_min - 1, + .x_max = bounding_box.x_max, + .y_min = bounding_box.y_min - 1, + .y_max = bounding_box.y_min}); + optional_boxes.push_back({.x_min = bounding_box.x_max, + .x_max = bounding_box.x_max + 1, + .y_min = bounding_box.y_min - 1, + .y_max = bounding_box.y_max}); + optional_boxes.push_back({.x_min = bounding_box.x_min, + .x_max = bounding_box.x_max + 1, + .y_min = bounding_box.y_max, + .y_max = bounding_box.y_max + 1}); + optional_boxes.push_back({.x_min = bounding_box.x_min - 1, + .x_max = bounding_box.x_min, + .y_min = bounding_box.y_min, + .y_max = bounding_box.y_max + 1}); + } + + // All items we added to `optional_boxes` at this point are only to be used by + // the "gap between items" logic below. They are not actual optional boxes and + // should be removed right after the logic is applied. + const int num_optional_boxes_to_remove = optional_boxes.size(); + + // Add a rectangle to `optional_boxes` but respecting that rectangles must + // remain disjoint. + const auto add_box = [&optional_boxes](Rectangle new_box) { + std::vector to_add = {std::move(new_box)}; + for (int i = 0; i < to_add.size(); ++i) { + Rectangle new_box = to_add[i]; + bool is_disjoint = true; + for (const Rectangle& existing_box : optional_boxes) { + if (!new_box.IsDisjoint(existing_box)) { + is_disjoint = false; + for (const Rectangle& disjoint_box : + new_box.SetDifference(existing_box)) { + to_add.push_back(disjoint_box); + } + break; + } + } + if (is_disjoint) { + optional_boxes.push_back(std::move(new_box)); + } + } + }; + + // Now check if there is any space that cannot be occupied by any non-fixed + // item. + std::vector bounding_boxes; + bounding_boxes.reserve(non_fixed_boxes.size()); + for (const RectangleInRange& box : non_fixed_boxes) { + bounding_boxes.push_back(box.bounding_area); + } + std::vector empty_spaces = + FindEmptySpaces(bounding_box, std::move(bounding_boxes)); + for (const Rectangle& r : empty_spaces) { + add_box(r); + } + + // Now look for gaps between objects that are too small to place anything. + for (int i = 1; i < optional_boxes.size(); ++i) { + const Rectangle cur_box = optional_boxes[i]; + for (int j = 0; j < i; ++j) { + const Rectangle& other_box = optional_boxes[j]; + const IntegerValue lower_top = std::min(cur_box.y_max, other_box.y_max); + const IntegerValue higher_bottom = + std::max(other_box.y_min, cur_box.y_min); + const IntegerValue rightmost_left_edge = + std::max(other_box.x_min, cur_box.x_min); + const IntegerValue leftmost_right_edge = + std::min(other_box.x_max, cur_box.x_max); + if (rightmost_left_edge < leftmost_right_edge) { + if (lower_top < higher_bottom && + higher_bottom - lower_top < min_y_size) { + add_box({.x_min = rightmost_left_edge, + .x_max = leftmost_right_edge, + .y_min = lower_top, + .y_max = higher_bottom}); + } + } + if (higher_bottom < lower_top) { + if (leftmost_right_edge < rightmost_left_edge && + rightmost_left_edge - leftmost_right_edge < min_x_size) { + add_box({.x_min = leftmost_right_edge, + .x_max = rightmost_left_edge, + .y_min = higher_bottom, + .y_max = lower_top}); + } + } + } + } + optional_boxes.erase(optional_boxes.begin(), + optional_boxes.begin() + num_optional_boxes_to_remove); + + if (ReduceNumberofBoxes(fixed_boxes, &optional_boxes)) { + changed = true; + } + if (changed && VLOG_IS_ON(1)) { + IntegerValue area = 0; + for (const Rectangle& r : *fixed_boxes) { + area += r.Area(); + } + VLOG_EVERY_N_SEC(1, 1) << "Presolved " << original_num_boxes + << " fixed rectangles (area=" << original_area + << ") into " << fixed_boxes->size() + << " (area=" << area << ")"; + + VLOG_EVERY_N_SEC(2, 2) << "Presolved rectangles:\n" + << RenderDot(bounding_box, fixed_boxes_copy) + << "Into:\n" + << RenderDot(bounding_box, *fixed_boxes) + << (optional_boxes.empty() + ? "" + : absl::StrCat("Unused optional rects:\n", + RenderDot(bounding_box, + optional_boxes))); + } + return changed; +} + +namespace { +struct Edge { + IntegerValue x_start; + IntegerValue y_start; + IntegerValue size; + + enum class EdgePosition { TOP, BOTTOM, LEFT, RIGHT }; + + static Edge GetEdge(const Rectangle& rectangle, EdgePosition pos) { + switch (pos) { + case EdgePosition::TOP: + return {.x_start = rectangle.x_min, + .y_start = rectangle.y_max, + .size = rectangle.SizeX()}; + case EdgePosition::BOTTOM: + return {.x_start = rectangle.x_min, + .y_start = rectangle.y_min, + .size = rectangle.SizeX()}; + case EdgePosition::LEFT: + return {.x_start = rectangle.x_min, + .y_start = rectangle.y_min, + .size = rectangle.SizeY()}; + case EdgePosition::RIGHT: + return {.x_start = rectangle.x_max, + .y_start = rectangle.y_min, + .size = rectangle.SizeY()}; + } + } + + template + friend H AbslHashValue(H h, const Edge& e) { + return H::combine(std::move(h), e.x_start, e.y_start, e.size); + } + + bool operator==(const Edge& other) const { + return x_start == other.x_start && y_start == other.y_start && + size == other.size; + } +}; +} // namespace + +bool ReduceNumberofBoxes(std::vector* mandatory_rectangles, + std::vector* optional_rectangles) { + // The current implementation just greedly merge rectangles that shares an + // edge. This is far from optimal, and it exists a polynomial optimal + // algorithm (see page 3 of [1]) for this problem at least for the case where + // optional_rectangles is empty. + // + // TODO(user): improve + // + // [1] Eppstein, David. "Graph-theoretic solutions to computational geometry + // problems." International Workshop on Graph-Theoretic Concepts in Computer + // Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2009. + std::vector> rectangle_storage; + enum class OptionalEnum { OPTIONAL, MANDATORY }; + // bool for is_optional + std::vector> rectangles_ptr; + absl::flat_hash_map top_edges_to_rectangle; + absl::flat_hash_map bottom_edges_to_rectangle; + absl::flat_hash_map left_edges_to_rectangle; + absl::flat_hash_map right_edges_to_rectangle; + + using EdgePosition = Edge::EdgePosition; + + bool changed_optional = false; + bool changed_mandatory = false; + + auto add_rectangle = [&](const Rectangle* rectangle_ptr, + OptionalEnum optional) { + const int index = rectangles_ptr.size(); + rectangles_ptr.push_back({rectangle_ptr, optional}); + const Rectangle& rectangle = *rectangles_ptr[index].first; + top_edges_to_rectangle[Edge::GetEdge(rectangle, EdgePosition::TOP)] = index; + bottom_edges_to_rectangle[Edge::GetEdge(rectangle, EdgePosition::BOTTOM)] = + index; + left_edges_to_rectangle[Edge::GetEdge(rectangle, EdgePosition::LEFT)] = + index; + right_edges_to_rectangle[Edge::GetEdge(rectangle, EdgePosition::RIGHT)] = + index; + }; + for (const Rectangle& rectangle : *mandatory_rectangles) { + add_rectangle(&rectangle, OptionalEnum::MANDATORY); + } + for (const Rectangle& rectangle : *optional_rectangles) { + add_rectangle(&rectangle, OptionalEnum::OPTIONAL); + } + + auto remove_rectangle = [&](const int index) { + const Rectangle& rectangle = *rectangles_ptr[index].first; + const Edge top_edge = Edge::GetEdge(rectangle, EdgePosition::TOP); + const Edge bottom_edge = Edge::GetEdge(rectangle, EdgePosition::BOTTOM); + const Edge left_edge = Edge::GetEdge(rectangle, EdgePosition::LEFT); + const Edge right_edge = Edge::GetEdge(rectangle, EdgePosition::RIGHT); + top_edges_to_rectangle.erase(top_edge); + bottom_edges_to_rectangle.erase(bottom_edge); + left_edges_to_rectangle.erase(left_edge); + right_edges_to_rectangle.erase(right_edge); + rectangles_ptr[index].first = nullptr; + }; + + bool iteration_did_merge = true; + while (iteration_did_merge) { + iteration_did_merge = false; + for (int i = 0; i < rectangles_ptr.size(); ++i) { + if (!rectangles_ptr[i].first) { + continue; + } + const Rectangle& rectangle = *rectangles_ptr[i].first; + + const Edge top_edge = Edge::GetEdge(rectangle, EdgePosition::TOP); + const Edge bottom_edge = Edge::GetEdge(rectangle, EdgePosition::BOTTOM); + const Edge left_edge = Edge::GetEdge(rectangle, EdgePosition::LEFT); + const Edge right_edge = Edge::GetEdge(rectangle, EdgePosition::RIGHT); + + int index = -1; + if (const auto it = right_edges_to_rectangle.find(left_edge); + it != right_edges_to_rectangle.end()) { + index = it->second; + } else if (const auto it = left_edges_to_rectangle.find(right_edge); + it != left_edges_to_rectangle.end()) { + index = it->second; + } else if (const auto it = bottom_edges_to_rectangle.find(top_edge); + it != bottom_edges_to_rectangle.end()) { + index = it->second; + } else if (const auto it = top_edges_to_rectangle.find(bottom_edge); + it != top_edges_to_rectangle.end()) { + index = it->second; + } + if (index == -1) { + continue; + } + iteration_did_merge = true; + + // Merge two rectangles! + const OptionalEnum new_optional = + (rectangles_ptr[i].second == OptionalEnum::MANDATORY || + rectangles_ptr[index].second == OptionalEnum::MANDATORY) + ? OptionalEnum::MANDATORY + : OptionalEnum::OPTIONAL; + changed_mandatory = + changed_mandatory || (new_optional == OptionalEnum::MANDATORY); + changed_optional = + changed_optional || + (rectangles_ptr[i].second == OptionalEnum::OPTIONAL || + rectangles_ptr[index].second == OptionalEnum::OPTIONAL); + rectangle_storage.push_back(std::make_unique(rectangle)); + Rectangle& new_rectangle = *rectangle_storage.back(); + new_rectangle.GrowToInclude(*rectangles_ptr[index].first); + remove_rectangle(i); + remove_rectangle(index); + add_rectangle(&new_rectangle, new_optional); + } + } + + if (changed_mandatory) { + std::vector new_rectangles; + for (auto [rectangle, optional] : rectangles_ptr) { + if (rectangle && optional == OptionalEnum::MANDATORY) { + new_rectangles.push_back(*rectangle); + } + } + *mandatory_rectangles = std::move(new_rectangles); + } + if (changed_optional) { + std::vector new_rectangles; + for (auto [rectangle, optional] : rectangles_ptr) { + if (rectangle && optional == OptionalEnum::OPTIONAL) { + new_rectangles.push_back(*rectangle); + } + } + *optional_rectangles = std::move(new_rectangles); + } + return changed_mandatory; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/2d_rectangle_presolve.h b/ortools/sat/2d_rectangle_presolve.h new file mode 100644 index 00000000000..d5cefb9c26b --- /dev/null +++ b/ortools/sat/2d_rectangle_presolve.h @@ -0,0 +1,50 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_2D_RECTANGLE_PRESOLVE_H_ +#define OR_TOOLS_SAT_2D_RECTANGLE_PRESOLVE_H_ + +#include + +#include "absl/types/span.h" +#include "ortools/sat/diffn_util.h" + +namespace operations_research { +namespace sat { + +// Given a set of fixed boxes and a set of boxes that are not yet +// fixed (but attributed a range), look for a more optimal set of fixed +// boxes that are equivalent to the initial set of fixed boxes. This +// uses "equivalent" in the sense that a placement of the non-fixed boxes will +// be non-overlapping with all other boxes if and only if it was with the +// original set of fixed boxes too. +bool PresolveFixed2dRectangles( + absl::Span non_fixed_boxes, + std::vector* fixed_boxes); + +// Given a set of non-overlapping rectangles split in two groups, mandatory and +// optional, try to build a set of as few non-overlapping rectangles as +// possible defining a region R that satisfy: +// - R \subset (mandatory \union optional); +// - mandatory \subset R. +// +// The function updates the set of `mandatory_rectangles` with `R` and +// `optional_rectangles` with `optional_rectangles \setdiff R`. It returns +// true if the `mandatory_rectangles` was updated. +bool ReduceNumberofBoxes(std::vector* mandatory_rectangles, + std::vector* optional_rectangles); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_2D_RECTANGLE_PRESOLVE_H_ diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 812cd9ab3da..004f357360c 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -667,6 +667,7 @@ cc_library( ], hdrs = ["cp_model_presolve.h"], deps = [ + ":2d_rectangle_presolve", ":circuit", ":clause", ":cp_model_cc_proto", @@ -1145,12 +1146,10 @@ cc_library( "//ortools/util:strong_integers", "//ortools/util:time_limit", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) @@ -1901,7 +1900,6 @@ cc_library( ":sat_base", ":sat_parameters_cc_proto", "//ortools/base", - "//ortools/base:mathlimits", "//ortools/base:mathutil", "//ortools/base:stl_util", "//ortools/util:random_engine", @@ -1909,6 +1907,8 @@ cc_library( "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -2036,6 +2036,21 @@ cc_library( ], ) +cc_library( + name = "2d_rectangle_presolve", + srcs = ["2d_rectangle_presolve.cc"], + hdrs = ["2d_rectangle_presolve.h"], + deps = [ + ":diffn_util", + ":integer", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "2d_orthogonal_packing_testing", testonly = 1, diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 510b38d6228..1b952d62bb7 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -208,7 +208,8 @@ bool ClauseManager::Propagate(Trail* trail) { } absl::Span ClauseManager::Reason(const Trail& /*trail*/, - int trail_index) const { + int trail_index, + int64_t /*conflict_id*/) const { return reasons_[trail_index]->PropagationReason(); } @@ -849,7 +850,7 @@ bool BinaryImplicationGraph::Propagate(Trail* trail) { } absl::Span BinaryImplicationGraph::Reason( - const Trail& /*trail*/, int trail_index) const { + const Trail& /*trail*/, int trail_index, int64_t /*conflict_id*/) const { return {&reasons_[trail_index], 1}; } diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index e49f1ffe3b6..4a30915cdce 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -171,8 +171,8 @@ class ClauseManager : public SatPropagator { // SatPropagator API. bool Propagate(Trail* trail) final; - absl::Span Reason(const Trail& trail, - int trail_index) const final; + absl::Span Reason(const Trail& trail, int trail_index, + int64_t conflict_id) const final; // Returns the reason of the variable at given trail_index. This only works // for variable propagated by this class and is almost the same as Reason() @@ -504,8 +504,8 @@ class BinaryImplicationGraph : public SatPropagator { // SatPropagator interface. bool Propagate(Trail* trail) final; - absl::Span Reason(const Trail& trail, - int trail_index) const final; + absl::Span Reason(const Trail& trail, int trail_index, + int64_t conflict_id) const final; // Resizes the data structure. void Resize(int num_variables); diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index 79f037a3bea..918a771ea0d 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -83,12 +83,41 @@ GreaterThanAtLeastOneOfPropagator::GreaterThanAtLeastOneOfPropagator( const absl::Span selectors, const absl::Span enforcements, Model* model) : target_var_(target_var), - exprs_(exprs.begin(), exprs.end()), - selectors_(selectors.begin(), selectors.end()), enforcements_(enforcements.begin(), enforcements.end()), + selectors_(selectors.begin(), selectors.end()), + exprs_(exprs.begin(), exprs.end()), trail_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()) {} +void GreaterThanAtLeastOneOfPropagator::Explain( + int id, IntegerValue propagation_slack, IntegerVariable /*var_to_explain*/, + int /*trail_index*/, std::vector* literals_reason, + std::vector* trail_indices_reason) { + literals_reason->clear(); + trail_indices_reason->clear(); + + const int first_non_false = id; + const IntegerValue target_min = propagation_slack; + + for (const Literal l : enforcements_) { + literals_reason->push_back(l.Negated()); + } + for (int i = 0; i < first_non_false; ++i) { + // If the level zero bounds is good enough, no reason needed. + // + // TODO(user): We could also skip this if we already have the reason for + // the expression being high enough in the current conflict. + if (integer_trail_->LevelZeroLowerBound(exprs_[i]) >= target_min) { + continue; + } + + literals_reason->push_back(selectors_[i]); + } + integer_trail_->AddAllGreaterThanConstantReason( + absl::MakeSpan(exprs_).subspan(first_non_false), target_min, + trail_indices_reason); +} + bool GreaterThanAtLeastOneOfPropagator::Propagate() { // TODO(user): In case of a conflict, we could push one of them to false if // it is the only one. @@ -101,41 +130,42 @@ bool GreaterThanAtLeastOneOfPropagator::Propagate() { // Propagate() calls. IntegerValue target_min = kMaxIntegerValue; const IntegerValue current_min = integer_trail_->LowerBound(target_var_); - for (int i = 0; i < exprs_.size(); ++i) { - if (trail_->Assignment().LiteralIsTrue(selectors_[i])) return true; - if (trail_->Assignment().LiteralIsFalse(selectors_[i])) continue; - target_min = std::min(target_min, integer_trail_->LowerBound(exprs_[i])); + const AssignmentView assignment(trail_->Assignment()); + + int first_non_false = 0; + const int size = exprs_.size(); + for (int i = 0; i < size; ++i) { + if (assignment.LiteralIsTrue(selectors_[i])) return true; + + // The permutation is needed to have proper lazy reason. + if (assignment.LiteralIsFalse(selectors_[i])) { + if (i != first_non_false) { + std::swap(selectors_[i], selectors_[first_non_false]); + std::swap(exprs_[i], exprs_[first_non_false]); + } + ++first_non_false; + continue; + } + + const IntegerValue min = integer_trail_->LowerBound(exprs_[i]); + if (min < target_min) { + target_min = min; - // Abort if we can't get a better bound. - if (target_min <= current_min) return true; + // Abort if we can't get a better bound. + if (target_min <= current_min) return true; + } } + if (target_min == kMaxIntegerValue) { // All false, conflit. *(trail_->MutableConflict()) = selectors_; return false; } - literal_reason_.clear(); - integer_reason_.clear(); - for (const Literal l : enforcements_) { - literal_reason_.push_back(l.Negated()); - } - for (int i = 0; i < exprs_.size(); ++i) { - // If the level zero bounds is good enough, no reason needed. - if (integer_trail_->LevelZeroLowerBound(exprs_[i]) >= target_min) { - continue; - } - if (trail_->Assignment().LiteralIsFalse(selectors_[i])) { - literal_reason_.push_back(selectors_[i]); - } else { - if (!exprs_[i].IsConstant()) { - integer_reason_.push_back(exprs_[i].GreaterOrEqual(target_min)); - } - } - } - return integer_trail_->Enqueue( - IntegerLiteral::GreaterOrEqual(target_var_, target_min), literal_reason_, - integer_reason_); + // Note that we use id/propagation_slack for other purpose. + return integer_trail_->EnqueueWithLazyReason( + IntegerLiteral::GreaterOrEqual(target_var_, target_min), + /*id=*/first_non_false, /*propagation_slack=*/target_min, this); } void GreaterThanAtLeastOneOfPropagator::RegisterWith( diff --git a/ortools/sat/cp_constraints.h b/ortools/sat/cp_constraints.h index e6612f40d9e..67a3a5fad67 100644 --- a/ortools/sat/cp_constraints.h +++ b/ortools/sat/cp_constraints.h @@ -71,7 +71,8 @@ class BooleanXorPropagator : public PropagatorInterface { // This constraint take care of this case when no selectors[i] is chosen yet. // // This constraint support duplicate selectors. -class GreaterThanAtLeastOneOfPropagator : public PropagatorInterface { +class GreaterThanAtLeastOneOfPropagator : public PropagatorInterface, + public LazyReasonInterface { public: GreaterThanAtLeastOneOfPropagator(IntegerVariable target_var, absl::Span exprs, @@ -88,17 +89,22 @@ class GreaterThanAtLeastOneOfPropagator : public PropagatorInterface { bool Propagate() final; void RegisterWith(GenericLiteralWatcher* watcher); + // For LazyReasonInterface. + void Explain(int id, IntegerValue propagation_slack, + IntegerVariable var_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) final; + private: const IntegerVariable target_var_; - const std::vector exprs_; - const std::vector selectors_; const std::vector enforcements_; + // Non-const as we swap elements around. + std::vector selectors_; + std::vector exprs_; + Trail* trail_; IntegerTrail* integer_trail_; - - std::vector literal_reason_; - std::vector integer_reason_; }; // ============================================================================ diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index e3d9b14dcaf..fe6309fd485 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -41,6 +41,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "google/protobuf/repeated_field.h" #include "google/protobuf/text_format.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" @@ -51,6 +52,7 @@ #include "ortools/graph/strongly_connected_components.h" #include "ortools/graph/topologicalsorter.h" #include "ortools/port/proto_utils.h" +#include "ortools/sat/2d_rectangle_presolve.h" #include "ortools/sat/circuit.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model.pb.h" @@ -2401,6 +2403,17 @@ bool CpModelPresolver::AddVarAffineRepresentativeFromLinearEquality( return CanonicalizeLinear(ct); } +namespace { + +bool IsLinearEqualityConstraint(const ConstraintProto& ct) { + return ct.constraint_case() == ConstraintProto::kLinear && + ct.linear().domain().size() == 2 && + ct.linear().domain(0) == ct.linear().domain(1) && + ct.enforcement_literal().empty(); +} + +} // namespace + // Any equality must be true modulo n. // // If the gcd of all but one term is not one, we can rewrite the last term using @@ -2414,10 +2427,7 @@ bool CpModelPresolver::AddVarAffineRepresentativeFromLinearEquality( // problem to two problem of half size. So at least we can do it in O(n log n). bool CpModelPresolver::PresolveLinearEqualityWithModulo(ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; - if (ct->constraint_case() != ConstraintProto::kLinear) return false; - if (ct->linear().domain().size() != 2) return false; - if (ct->linear().domain(0) != ct->linear().domain(1)) return false; - if (!ct->enforcement_literal().empty()) return false; + if (!IsLinearEqualityConstraint(*ct)) return false; const int num_variables = ct->linear().vars().size(); if (num_variables < 2) return false; @@ -5633,11 +5643,14 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { bool x_constant = true; bool y_constant = true; bool has_zero_sized_interval = false; + bool has_potential_zero_sized_interval = false; // Filter absent boxes. int new_size = 0; - std::vector bounding_boxes; + std::vector bounding_boxes, fixed_boxes; + std::vector non_fixed_boxes; std::vector active_boxes; + absl::flat_hash_set fixed_item_indexes; for (int i = 0; i < proto.x_intervals_size(); ++i) { const int x_interval_index = proto.x_intervals(i); const int y_interval_index = proto.y_intervals(i); @@ -5655,6 +5668,19 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { IntegerValue(context_->StartMin(y_interval_index)), IntegerValue(context_->EndMax(y_interval_index))}); active_boxes.push_back(new_size); + if (context_->IntervalIsConstant(x_interval_index) && + context_->IntervalIsConstant(y_interval_index) && + context_->SizeMax(x_interval_index) > 0 && + context_->SizeMax(y_interval_index) > 0) { + fixed_boxes.push_back(bounding_boxes.back()); + fixed_item_indexes.insert(new_size); + } else { + non_fixed_boxes.push_back( + {.box_index = new_size, + .bounding_area = bounding_boxes.back(), + .x_size = context_->SizeMin(x_interval_index), + .y_size = context_->SizeMin(y_interval_index)}); + } new_size++; if (x_constant && !context_->IntervalIsConstant(x_interval_index)) { @@ -5667,6 +5693,10 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { context_->SizeMax(y_interval_index) == 0) { has_zero_sized_interval = true; } + if (context_->SizeMin(x_interval_index) == 0 || + context_->SizeMin(y_interval_index) == 0) { + has_potential_zero_sized_interval = true; + } } std::vector> components = GetOverlappingRectangleComponents( @@ -5736,6 +5766,69 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { return RemoveConstraint(ct); } + // We check if the fixed boxes are not overlapping so downstream code can + // assume it to be true. + for (int i = 0; i < fixed_boxes.size(); ++i) { + const Rectangle& fixed_box = fixed_boxes[i]; + for (int j = i + 1; j < fixed_boxes.size(); ++j) { + const Rectangle& other_fixed_box = fixed_boxes[j]; + if (!fixed_box.IsDisjoint(other_fixed_box)) { + return context_->NotifyThatModelIsUnsat( + "Two fixed boxes in no_overlap_2d overlap"); + } + } + } + + if (fixed_boxes.size() == active_boxes.size()) { + context_->UpdateRuleStats("no_overlap_2d: all boxes are fixed"); + return RemoveConstraint(ct); + } + + // TODO(user): presolve the zero-size fixed items so they are disjoint from + // the other fixed items. Then the following presolve is still valid. On the + // other hand, we cannot do much with non-fixed zero-size items. + if (!has_potential_zero_sized_interval && !fixed_boxes.empty()) { + const bool presolved = + PresolveFixed2dRectangles(non_fixed_boxes, &fixed_boxes); + if (presolved) { + NoOverlap2DConstraintProto new_no_overlap_2d; + + // Replace the old fixed intervals by the new ones. + const int old_size = proto.x_intervals_size(); + for (int i = 0; i < old_size; ++i) { + if (fixed_item_indexes.contains(i)) { + continue; + } + new_no_overlap_2d.add_x_intervals(proto.x_intervals(i)); + new_no_overlap_2d.add_y_intervals(proto.y_intervals(i)); + } + for (const Rectangle& fixed_box : fixed_boxes) { + const int item_x_interval = + context_->working_model->constraints().size(); + IntervalConstraintProto* new_interval = + context_->working_model->add_constraints()->mutable_interval(); + new_interval->mutable_start()->set_offset(fixed_box.x_min.value()); + new_interval->mutable_size()->set_offset(fixed_box.SizeX().value()); + new_interval->mutable_end()->set_offset(fixed_box.x_max.value()); + + const int item_y_interval = + context_->working_model->constraints().size(); + new_interval = + context_->working_model->add_constraints()->mutable_interval(); + new_interval->mutable_start()->set_offset(fixed_box.y_min.value()); + new_interval->mutable_size()->set_offset(fixed_box.SizeY().value()); + new_interval->mutable_end()->set_offset(fixed_box.y_max.value()); + + new_no_overlap_2d.add_x_intervals(item_x_interval); + new_no_overlap_2d.add_y_intervals(item_y_interval); + } + context_->working_model->add_constraints()->mutable_no_overlap_2d()->Swap( + &new_no_overlap_2d); + context_->UpdateNewConstraintsVariableUsage(); + context_->UpdateRuleStats("no_overlap_2d: presolved fixed rectangles"); + return RemoveConstraint(ct); + } + } return new_size < initial_num_boxes; } @@ -7516,6 +7609,7 @@ void CpModelPresolver::ShiftObjectiveWithExactlyOnes() { // This assumes we are more or less at the propagation fix point, even if we // try to address cases where we are not. void CpModelPresolver::ExpandObjective() { + if (time_limit_->LimitReached()) return; if (context_->ModelIsUnsat()) return; PresolveTimer timer(__FUNCTION__, logger_, time_limit_); @@ -7558,6 +7652,15 @@ void CpModelPresolver::ExpandObjective() { // Deal with exactly one. // An exactly one is always tight on the upper bound of one term. + // + // Note(user): This code assume there is no fixed variable in the exactly + // one. We thus make sure the constraint is re-presolved if for some reason + // we didn't reach the fixed point before calling this code. + if (ct.constraint_case() == ConstraintProto::kExactlyOne) { + if (PresolveExactlyOne(context_->working_model->mutable_constraints(c))) { + context_->UpdateConstraintVariableUsage(c); + } + } if (ct.constraint_case() == ConstraintProto::kExactlyOne) { const int num_terms = ct.exactly_one().literals().size(); ++num_tight_constraints; @@ -7583,11 +7686,7 @@ void CpModelPresolver::ExpandObjective() { } // Skip everything that is not a linear equality constraint. - if (ct.constraint_case() != ConstraintProto::kLinear || - ct.linear().domain().size() != 2 || - ct.linear().domain(0) != ct.linear().domain(1)) { - continue; - } + if (!IsLinearEqualityConstraint(ct)) continue; // Let see for which variable is it "tight". We need a coeff of 1, and that // the implied bounds match exactly. @@ -9105,11 +9204,13 @@ void CpModelPresolver::DetectDifferentVariables() { continue; } + const int lit1 = ct1.enforcement_literal(0); + const int lit2 = ct2.enforcement_literal(0); + // Detect x != y via lit => x > y && not(lit) => x < y. if (ct1.linear().vars().size() == 2 && ct1.linear().coeffs(0) == -ct1.linear().coeffs(1) && - ct1.enforcement_literal(0) == - NegatedRef(ct2.enforcement_literal(0))) { + lit1 == NegatedRef(lit2)) { // We have x - y in domain1 or in domain2, so it must be in the union. Domain union_of_domain = ReadDomainFromProto(ct1.linear()) @@ -9125,9 +9226,10 @@ void CpModelPresolver::DetectDifferentVariables() { } } - context_->UpdateRuleStats("incompatible linear: add implication"); - context_->AddImplication(ct1.enforcement_literal(0), - NegatedRef(ct2.enforcement_literal(0))); + if (lit1 != NegatedRef(lit2)) { + context_->UpdateRuleStats("incompatible linear: add implication"); + context_->AddImplication(lit1, NegatedRef(lit2)); + } } } } @@ -9593,10 +9695,10 @@ bool CpModelPresolver::RemoveCommonPart( int definiting_equation = -1; for (const auto [c, multiple] : block) { const ConstraintProto& ct = context_->working_model->constraints(c); - if (ct.linear().vars().size() != common_var_coeff_map.size() + 1) continue; - if (ct.linear().domain(0) != ct.linear().domain(1)) continue; - if (!ct.enforcement_literal().empty()) continue; if (std::abs(multiple) != 1) continue; + if (!IsLinearEqualityConstraint(ct)) continue; + if (ct.linear().vars().size() != common_var_coeff_map.size() + 1) continue; + context_->UpdateRuleStats( "linear matrix: defining equation for common rectangle"); definiting_equation = c; @@ -10385,11 +10487,8 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { const int num_constraints = context_->working_model->constraints_size(); for (int c = 0; c < num_constraints; ++c) { const ConstraintProto& ct = context_->working_model->constraints(c); - if (ct.constraint_case() != ConstraintProto::kLinear) continue; - if (!ct.enforcement_literal().empty()) continue; + if (!IsLinearEqualityConstraint(ct)) continue; if (ct.linear().vars().size() <= 2) continue; - if (ct.linear().domain().size() != 2) continue; - if (ct.linear().domain(0) != ct.linear().domain(1)) continue; // Our canonicalization should sort constraints, we skip non-canonical ones. if (!std::is_sorted(ct.linear().vars().begin(), ct.linear().vars().end())) { @@ -10544,9 +10643,7 @@ void CpModelPresolver::ExtractEncodingFromLinear() { } case ConstraintProto::kLinear: { // We only consider equality with no enforcement. - if (!ct.enforcement_literal().empty()) continue; - if (ct.linear().domain().size() != 2) continue; - if (ct.linear().domain(0) != ct.linear().domain(1)) continue; + if (!IsLinearEqualityConstraint(ct)) continue; // We also want a single non-Boolean. // Note that this assume the constraint is canonicalized. @@ -11435,6 +11532,7 @@ bool CpModelPresolver::ProcessChangedVariables(std::vector* in_queue, } void CpModelPresolver::PresolveToFixPoint() { + if (time_limit_->LimitReached()) return; if (context_->ModelIsUnsat()) return; PresolveTimer timer(__FUNCTION__, logger_, time_limit_); diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 371f2fa482b..993acaf5194 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -24,6 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/flags/flag.h" #include "absl/log/check.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" @@ -644,14 +645,6 @@ absl::flat_hash_map GetNamedParameters( new_params.set_use_dynamic_precedence_in_disjunctive(false); new_params.set_use_dynamic_precedence_in_cumulative(false); strategies["fixed"] = new_params; - - new_params.set_linearization_level(0); - strategies["fixed_no_lp"] = new_params; - - new_params.set_linearization_level(2); - new_params.set_add_lp_constraints_lazily(false); - new_params.set_root_lp_iterations(100'000); - strategies["fixed_max_lp"] = new_params; } // Quick restart. diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 091309f2992..6a8c52ad46d 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -609,9 +609,9 @@ std::string CpSolverResponseStats(const CpSolverResponse& response, namespace { -void LogSubsolverNames( - const std::vector>& subsolvers, - absl::Span ignored, SolverLogger* logger) { +void LogSubsolverNames(absl::Span> subsolvers, + absl::Span ignored, + SolverLogger* logger) { if (!logger->LoggingIsEnabled()) return; std::vector full_problem_solver_names; diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index 981da952e0f..0cb6fbb67c0 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -49,6 +50,53 @@ bool Rectangle::IsDisjoint(const Rectangle& other) const { other.y_min >= y_max; } +absl::InlinedVector Rectangle::SetDifference( + const Rectangle& other) const { + const Rectangle intersect = Intersect(other); + if (intersect.SizeX() == 0) { + return {*this}; + } + + //------------------- + //| | 4 | | + //| |---------| | + //| 1 | other | 2 | + //| |---------| | + //| | 3 | | + //------------------- + absl::InlinedVector result; + if (x_min < intersect.x_min) { + // Piece 1 + result.push_back({.x_min = x_min, + .x_max = intersect.x_min, + .y_min = y_min, + .y_max = y_max}); + } + if (x_max > intersect.x_max) { + // Piece 2 + result.push_back({.x_min = intersect.x_max, + .x_max = x_max, + .y_min = y_min, + .y_max = y_max}); + } + if (y_min < intersect.y_min) { + // Piece 3 + result.push_back({.x_min = intersect.x_min, + .x_max = intersect.x_max, + .y_min = y_min, + .y_max = intersect.y_min}); + } + if (y_max > intersect.y_max) { + // Piece 4 + result.push_back({.x_min = intersect.x_min, + .x_max = intersect.x_max, + .y_min = intersect.y_max, + .y_max = y_max}); + } + + return result; +} + std::vector> GetOverlappingRectangleComponents( absl::Span rectangles, absl::Span active_rectangles) { if (active_rectangles.empty()) return {}; @@ -160,7 +208,7 @@ bool BoxesAreInEnergyConflict(const std::vector& rectangles, for (int k = 0; k < i; ++k) { const int task_index = boxes_by_increasing_y_max[k].task_index; if (rectangles[task_index].y_min >= y_starts[j]) { - conflict->TakeUnionWith(rectangles[task_index]); + conflict->GrowToInclude(rectangles[task_index]); } } } @@ -280,7 +328,7 @@ bool AnalyzeIntervals(bool transpose, absl::Span local_boxes, ? rectangles[task_index].y_min : rectangles[task_index].x_min; if (task_x_min < starts[j]) continue; - conflict->TakeUnionWith(rectangles[task_index]); + conflict->GrowToInclude(rectangles[task_index]); } } return false; @@ -1490,18 +1538,20 @@ FindRectanglesResult FindRectanglesWithEnergyConflictMC( return result; } -std::string RenderDot(std::pair bb_sizes, +std::string RenderDot(std::optional bb, absl::Span solution) { const std::vector colors = {"red", "green", "blue", "cyan", "yellow", "purple"}; std::stringstream ss; ss << "digraph {\n"; - ss << " graph [ bgcolor=lightgray width=" << 2 * bb_sizes.first - << " height=" << 2 * bb_sizes.second << "]\n"; + ss << " graph [ bgcolor=lightgray ]\n"; ss << " node [style=filled]\n"; - ss << " bb [fillcolor=\"grey\" pos=\"" << bb_sizes.first << "," - << bb_sizes.second << "!\" shape=box width=" << 2 * bb_sizes.first - << " height=" << 2 * bb_sizes.second << "]\n"; + if (bb.has_value()) { + ss << " bb [fillcolor=\"grey\" pos=\"" << 2 * bb->x_min + bb->SizeX() + << "," << 2 * bb->y_min + bb->SizeY() + << "!\" shape=box width=" << 2 * bb->SizeX() + << " height=" << 2 * bb->SizeY() << "]\n"; + } for (int i = 0; i < solution.size(); ++i) { ss << " " << i << " [fillcolor=\"" << colors[i % colors.size()] << "\" pos=\"" << 2 * solution[i].x_min + solution[i].SizeX() << "," @@ -1513,5 +1563,30 @@ std::string RenderDot(std::pair bb_sizes, return ss.str(); } +std::vector FindEmptySpaces( + const Rectangle& bounding_box, std::vector ocupied_rectangles) { + std::vector empty_spaces = {bounding_box}; + std::vector new_empty_spaces; + // Sorting is not necessary for correctness but makes it faster. + std::sort(ocupied_rectangles.begin(), ocupied_rectangles.end(), + [](const Rectangle& a, const Rectangle& b) { + return std::tuple(a.x_min, -a.x_max, a.y_min) < + std::tuple(b.x_min, -b.x_max, b.y_min); + }); + for (const Rectangle& ocupied_rectangle : ocupied_rectangles) { + new_empty_spaces.clear(); + for (const auto& empty_space : empty_spaces) { + for (Rectangle& r : empty_space.SetDifference(ocupied_rectangle)) { + new_empty_spaces.push_back(std::move(r)); + } + } + empty_spaces.swap(new_empty_spaces); + if (empty_spaces.empty()) { + break; + } + } + return empty_spaces; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index cc5f4faae2d..0c3eac4c626 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -25,6 +25,7 @@ #include #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/random/bit_gen_ref.h" #include "absl/strings/str_format.h" @@ -42,7 +43,7 @@ struct Rectangle { IntegerValue y_min; IntegerValue y_max; - void TakeUnionWith(const Rectangle& other) { + void GrowToInclude(const Rectangle& other) { x_min = std::min(x_min, other.x_min); y_min = std::min(y_min, other.y_min); x_max = std::max(x_max, other.x_max); @@ -60,6 +61,10 @@ struct Rectangle { Rectangle Intersect(const Rectangle& other) const; IntegerValue IntersectArea(const Rectangle& other) const; + // Returns `this \ other` as a set of disjoint rectangles of non-empty area. + // The resulting vector will have at most four elements. + absl::InlinedVector SetDifference(const Rectangle& other) const; + template friend void AbslStringify(Sink& sink, const Rectangle& r) { absl::Format(&sink, "rectangle(x(%i..%i), y(%i..%i))", r.x_min.value(), @@ -399,6 +404,8 @@ struct RectangleInRange { } } + Rectangle GetBoudingBox() const { return bounding_area; } + // Returns an empty rectangle if it is possible for no intersection to happen. Rectangle GetMinimumIntersection(const Rectangle& containing_area) const { IntegerValue smallest_area = std::numeric_limits::max(); @@ -591,9 +598,15 @@ FindRectanglesResult FindRectanglesWithEnergyConflictMC( // Render a packing solution as a Graphviz dot file. Only works in the "neato" // or "fdp" Graphviz backends. -std::string RenderDot(std::pair bb_sizes, +std::string RenderDot(std::optional bb, absl::Span solution); +// Given a bounding box and a list of rectangles inside that bounding box, +// returns a list of rectangles partitioning the empty area inside the bounding +// box. +std::vector FindEmptySpaces( + const Rectangle& bounding_box, std::vector ocupied_rectangles); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 534cdf18787..ea6b4885af0 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -711,6 +711,7 @@ bool IntegerTrail::Propagate(Trail* trail) { // be empty. if (level > integer_search_levels_.size()) { integer_search_levels_.push_back(integer_trail_.size()); + lazy_reason_decision_levels_.push_back(lazy_reasons_.size()); reason_decision_levels_.push_back(literals_reason_starts_.size()); CHECK_EQ(level, integer_search_levels_.size()); } @@ -791,6 +792,10 @@ void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) { } integer_trail_.resize(target); + // Resize lazy reason. + lazy_reasons_.resize(lazy_reason_decision_levels_[level]); + lazy_reason_decision_levels_.resize(level); + // Clear reason. const int old_size = reason_decision_levels_[level]; reason_decision_levels_.resize(level); @@ -805,6 +810,7 @@ void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) { literals_reason_starts_.resize(old_size); bounds_reason_starts_.resize(old_size); + cached_sizes_.resize(old_size); } // We notify the new level once all variables have been restored to their @@ -824,6 +830,8 @@ void IntegerTrail::ReserveSpaceForNumVariables(int num_vars) { integer_trail_.reserve(size); var_trail_index_cache_.reserve(size); tmp_var_to_trail_index_in_queue_.reserve(size); + + var_to_trail_index_at_lower_level_.reserve(size); } IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, @@ -849,6 +857,7 @@ IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, var_trail_index_cache_.resize(var_lbs_.size(), integer_trail_.size()); tmp_var_to_trail_index_in_queue_.resize(var_lbs_.size(), 0); + var_to_trail_index_at_lower_level_.resize(var_lbs_.size(), 0); for (SparseBitset* w : watchers_) { w->Resize(NumIntegerVariables()); @@ -939,10 +948,19 @@ int IntegerTrail::FindTrailIndexOfVarBefore(IntegerVariable var, // Optimization. We assume this is only called when computing a reason, so we // can ignore this trail_index if we already need a more restrictive reason // for this var. + // + // Hacky: We know this is only called with threshold == trail_index of the + // trail entry we are trying to explain. So this test can only trigger when a + // variable was shown to be already implied by the current conflict. const int index_in_queue = tmp_var_to_trail_index_in_queue_[var]; if (threshold <= index_in_queue) { - if (index_in_queue != std::numeric_limits::max()) - has_dependency_ = true; + // Disable the other optim if we might expand this literal during + // 1-UIP resolution. + const int last_decision_index = + integer_search_levels_.empty() ? 0 : integer_search_levels_.back(); + if (index_in_queue >= last_decision_index) { + info_is_valid_on_subsequent_last_level_expansion_ = false; + } return -1; } @@ -1157,9 +1175,7 @@ std::vector* IntegerTrail::InitializeConflict( if (use_lazy_reason) { // We use the current trail index here. conflict->clear(); - const int trail_index = integer_trail_.size(); - lazy_reasons_[trail_index].Explain(integer_literal, trail_index, conflict, - &tmp_queue_); + lazy_reasons_.back().Explain(conflict, &tmp_queue_); } else { conflict->assign(literals_reason.begin(), literals_reason.end()); const int num_vars = var_lbs_.size(); @@ -1399,20 +1415,15 @@ void IntegerTrail::EnqueueLiteralInternal( } const int trail_index = trail_->Index(); - if (trail_index >= boolean_trail_index_to_integer_one_.size()) { - boolean_trail_index_to_integer_one_.resize(trail_index + 1); + if (trail_index >= boolean_trail_index_to_reason_index_.size()) { + boolean_trail_index_to_reason_index_.resize(trail_index + 1); } - boolean_trail_index_to_integer_one_[trail_index] = integer_trail_.size(); const int reason_index = use_lazy_reason - ? -1 + ? -static_cast(lazy_reasons_.size()) : AppendReasonToInternalBuffers(literal_reason, integer_reason); - - integer_trail_.push_back({/*bound=*/IntegerValue(0), - /*var=*/kNoIntegerVariable, - /*prev_trail_index=*/-1, - /*reason_index=*/reason_index}); + boolean_trail_index_to_reason_index_[trail_index] = reason_index; trail_->Enqueue(literal, propagator_id_); } @@ -1494,6 +1505,7 @@ int IntegerTrail::AppendReasonToInternalBuffers( absl::Span integer_reason) { const int reason_index = literals_reason_starts_.size(); DCHECK_EQ(reason_index, bounds_reason_starts_.size()); + DCHECK_EQ(reason_index, cached_sizes_.size()); literals_reason_starts_.push_back(literals_reason_buffer_.size()); if (!literal_reason.empty()) { @@ -1502,6 +1514,7 @@ int IntegerTrail::AppendReasonToInternalBuffers( literal_reason.end()); } + cached_sizes_.push_back(-1); bounds_reason_starts_.push_back(bounds_reason_buffer_.size()); if (!integer_reason.empty()) { bounds_reason_buffer_.insert(bounds_reason_buffer_.end(), @@ -1511,6 +1524,10 @@ int IntegerTrail::AppendReasonToInternalBuffers( return reason_index; } +int64_t IntegerTrail::NextConflictId() { + return sat_solver_->num_failures() + 1; +} + bool IntegerTrail::EnqueueInternal( IntegerLiteral i_lit, bool use_lazy_reason, absl::Span literal_reason, @@ -1545,10 +1562,9 @@ bool IntegerTrail::EnqueueInternal( integer_reason); { const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason); - const int num_vars = var_lbs_.size(); // must be signed. - if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); + if (trail_index >= 0) tmp_queue_.push_back(trail_index); } - MergeReasonIntoInternal(conflict); + MergeReasonIntoInternal(conflict, NextConflictId()); return false; } @@ -1594,13 +1610,14 @@ bool IntegerTrail::EnqueueInternal( IntegerValue bound; const LiteralIndex literal_index = encoder_->SearchForLiteralAtOrBefore(i_lit, &bound); + int bool_index = -1; if (literal_index != kNoLiteralIndex) { const Literal to_enqueue = Literal(literal_index); if (trail_->Assignment().LiteralIsFalse(to_enqueue)) { auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason, integer_reason); conflict->push_back(to_enqueue); - MergeReasonIntoInternal(conflict); + MergeReasonIntoInternal(conflict, NextConflictId()); return false; } @@ -1624,12 +1641,7 @@ bool IntegerTrail::EnqueueInternal( // Subtle: the reason is the same as i_lit, that we will enqueue if no // conflict occur at position integer_trail_.size(), so we just refer to // this index here. - const int trail_index = trail_->Index(); - if (trail_index >= boolean_trail_index_to_integer_one_.size()) { - boolean_trail_index_to_integer_one_.resize(trail_index + 1); - } - boolean_trail_index_to_integer_one_[trail_index] = - integer_trail_.size(); + bool_index = trail_->Index(); trail_->Enqueue(to_enqueue, propagator_id_); } } @@ -1659,13 +1671,19 @@ bool IntegerTrail::EnqueueInternal( int reason_index; if (use_lazy_reason) { - reason_index = -1; + reason_index = -static_cast(lazy_reasons_.size()); } else if (trail_index_with_same_reason >= integer_trail_.size()) { reason_index = AppendReasonToInternalBuffers(literal_reason, integer_reason); } else { reason_index = integer_trail_[trail_index_with_same_reason].reason_index; } + if (bool_index >= 0) { + if (bool_index >= boolean_trail_index_to_reason_index_.size()) { + boolean_trail_index_to_reason_index_.resize(bool_index + 1); + } + boolean_trail_index_to_reason_index_[bool_index] = reason_index; + } const int prev_trail_index = var_trail_index_[i_lit.var]; integer_trail_.push_back({/*bound=*/i_lit.bound, @@ -1729,63 +1747,53 @@ bool IntegerTrail::EnqueueAssociatedIntegerLiteral(IntegerLiteral i_lit, return true; } -void IntegerTrail::ComputeLazyReasonIfNeeded(int trail_index) const { - const int reason_index = integer_trail_[trail_index].reason_index; - if (reason_index == -1) { - const TrailEntry& entry = integer_trail_[trail_index]; - const IntegerLiteral literal(entry.var, entry.bound); - lazy_reasons_[trail_index].Explain(literal, trail_index, - &lazy_reason_literals_, - &lazy_reason_trail_indices_); +void IntegerTrail::ComputeLazyReasonIfNeeded(int reason_index) const { + if (reason_index < 0) { + lazy_reasons_[-reason_index - 1].Explain(&lazy_reason_literals_, + &lazy_reason_trail_indices_); } } -absl::Span IntegerTrail::Dependencies(int trail_index) const { - const int reason_index = integer_trail_[trail_index].reason_index; - if (reason_index == -1) { +absl::Span IntegerTrail::Dependencies(int reason_index) const { + if (reason_index < 0) { return absl::Span(lazy_reason_trail_indices_); } + const int cached_size = cached_sizes_[reason_index]; + if (cached_size == 0) return {}; + const int start = bounds_reason_starts_[reason_index]; + if (cached_size > 0) { + return absl::MakeSpan(&trail_index_reason_buffer_[start], cached_size); + } + + // Else we cache. + DCHECK_EQ(cached_size, -1); const int end = reason_index + 1 < bounds_reason_starts_.size() ? bounds_reason_starts_[reason_index + 1] : bounds_reason_buffer_.size(); - if (start == end) return {}; - - // Cache the result if not already computed. Remark, if the result was never - // computed then the span trail_index_reason_buffer_[start, end) will either - // be non-existent or full of -1. - // - // TODO(user): For empty reason, we will always recompute them. if (end > trail_index_reason_buffer_.size()) { - trail_index_reason_buffer_.resize(end, -1); + trail_index_reason_buffer_.resize(end); } - if (trail_index_reason_buffer_[start] == -1) { - int new_end = start; - const int num_vars = var_lbs_.size(); - for (int i = start; i < end; ++i) { - const int dep = - FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]); - if (dep >= num_vars) { - trail_index_reason_buffer_[new_end++] = dep; - } + + int new_size = 0; + int* data = trail_index_reason_buffer_.data() + start; + const int num_vars = var_lbs_.size(); + for (int i = start; i < end; ++i) { + const int dep = + FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]); + if (dep >= num_vars) { + data[new_size++] = dep; } - return absl::Span(&trail_index_reason_buffer_[start], - new_end - start); - } else { - // TODO(user): We didn't store new_end in a previous call, so end might be - // larger. That is a bit annoying since we have to test for -1 while - // iterating. - return absl::Span(&trail_index_reason_buffer_[start], - end - start); } + cached_sizes_[reason_index] = new_size; + if (new_size == 0) return {}; + return absl::MakeSpan(data, new_size); } -void IntegerTrail::AppendLiteralsReason(int trail_index, +void IntegerTrail::AppendLiteralsReason(int reason_index, std::vector* output) const { - CHECK_GE(trail_index, var_lbs_.size()); - const int reason_index = integer_trail_[trail_index].reason_index; - if (reason_index == -1) { + if (reason_index < 0) { for (const Literal l : lazy_reason_literals_) { if (!added_variables_[l.Variable()]) { added_variables_.Set(l.Variable()); @@ -1826,17 +1834,35 @@ void IntegerTrail::MergeReasonInto(absl::Span literals, // Note that it is important for size to be signed because of -1 indices. if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); } - return MergeReasonIntoInternal(output); + return MergeReasonIntoInternal(output, -1); } // This will expand the reason of the IntegerLiteral already in tmp_queue_ until // everything is explained in term of Literal. -void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { +void IntegerTrail::MergeReasonIntoInternal(std::vector* output, + int64_t conflict_id) const { // All relevant trail indices will be >= var_lbs_.size(), so we can safely use // zero to means that no literal referring to this variable is in the queue. DCHECK(std::all_of(tmp_var_to_trail_index_in_queue_.begin(), tmp_var_to_trail_index_in_queue_.end(), [](int v) { return v == 0; })); + DCHECK(tmp_to_clear_.empty()); + + info_is_valid_on_subsequent_last_level_expansion_ = true; + if (conflict_id == -1 || last_conflict_id_ != conflict_id) { + // New conflict or a reason was asked outside first UIP resolution. + // We just clear everything. + last_conflict_id_ = conflict_id; + for (const IntegerVariable var : to_clear_for_lower_level_) { + var_to_trail_index_at_lower_level_[var] = 0; + } + to_clear_for_lower_level_.clear(); + } + + const int last_decision_index = + integer_search_levels_.empty() || conflict_id == -1 + ? 0 + : integer_search_levels_.back(); added_variables_.ClearAndResize(BooleanVariable(trail_->NumVariables())); for (const Literal l : *output) { @@ -1860,8 +1886,9 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { // We process the entries by highest trail_index first. The content of the // queue will always be a valid reason for the literals we already added to // the output. - tmp_to_clear_.clear(); + int64_t work_done = 0; while (!tmp_queue_.empty()) { + ++work_done; const int trail_index = tmp_queue_.front(); const TrailEntry& entry = integer_trail_[trail_index]; std::pop_heap(tmp_queue_.begin(), tmp_queue_.end()); @@ -1874,6 +1901,35 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { continue; } + // Process this entry. Note that if any of the next expansion include the + // variable entry.var in their reason, we must process it again because we + // cannot easily detect if it was needed to infer the current entry. + // + // Important: the queue might already contains entries referring to the same + // variable. The code act like if we deleted all of them at this point, we + // just do that lazily. tmp_var_to_trail_index_in_queue_[var] will + // only refer to newly added entries. + // + // TODO(user): We can and should reset that to the initial value from + // var_to_trail_index_at_lower_level_ instead of zero. + tmp_var_to_trail_index_in_queue_[entry.var] = 0; + has_dependency_ = false; + + // Skip entries that we known are already explained by the part of the + // conflict not involving the last level. + if (var_to_trail_index_at_lower_level_[entry.var] >= trail_index) { + continue; + } + + // If this literal is not at the highest level, it will always be + // propagated by the current conflict (even after some 1-UIP resolution + // step). We save this fact so that future MergeReasonIntoInternal() on + // the same conflict can just avoid to expand integer literal that are + // already known to be implied. + if (trail_index < last_decision_index) { + tmp_seen_.push_back(trail_index); + } + // Set the cache threshold. Since we process trail indices in decreasing // order and we only have single linked list, we only want to advance the // "cache" up to this threshold. @@ -1888,7 +1944,7 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { if (associated_lit != kNoLiteralIndex) { // We check that the reason is the same! const int reason_index = integer_trail_[trail_index].reason_index; - CHECK_NE(reason_index, -1); + CHECK_GE(reason_index, 0); { const int start = literals_reason_starts_[reason_index]; const int end = reason_index + 1 < literals_reason_starts_.size() @@ -1915,21 +1971,11 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { } } - // Process this entry. Note that if any of the next expansion include the - // variable entry.var in their reason, we must process it again because we - // cannot easily detect if it was needed to infer the current entry. - // - // Important: the queue might already contains entries referring to the same - // variable. The code act like if we deleted all of them at this point, we - // just do that lazily. tmp_var_to_trail_index_in_queue_[var] will - // only refer to newly added entries. - tmp_var_to_trail_index_in_queue_[entry.var] = 0; - has_dependency_ = false; - - ComputeLazyReasonIfNeeded(trail_index); - AppendLiteralsReason(trail_index, output); - for (const int next_trail_index : Dependencies(trail_index)) { - if (next_trail_index < 0) break; + ComputeLazyReasonIfNeeded(entry.reason_index); + AppendLiteralsReason(entry.reason_index, output); + const auto dependencies = Dependencies(entry.reason_index); + work_done += dependencies.size(); + for (const int next_trail_index : dependencies) { DCHECK_LT(next_trail_index, trail_index); const TrailEntry& next_entry = integer_trail_[next_trail_index]; @@ -1939,8 +1985,24 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { // in the queue referring to the same variable. const int index_in_queue = tmp_var_to_trail_index_in_queue_[next_entry.var]; - if (index_in_queue != std::numeric_limits::max()) - has_dependency_ = true; + + // This means the integer literal had no dependency and is already + // explained by the literal we added. + if (index_in_queue >= trail_index) { + // Disable the other optim if we might expand this literal during + // 1-UIP resolution. + if (index_in_queue >= last_decision_index) { + info_is_valid_on_subsequent_last_level_expansion_ = false; + } + continue; + } + + if (next_trail_index <= + var_to_trail_index_at_lower_level_[next_entry.var]) { + continue; + } + + has_dependency_ = true; if (next_trail_index > index_in_queue) { tmp_var_to_trail_index_in_queue_[next_entry.var] = next_trail_index; tmp_queue_.push_back(next_trail_index); @@ -1948,37 +2010,55 @@ void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { } } - // Special case for a "leaf", we will never need this variable again. + // Special case for a "leaf", we will never need this variable again in the + // current explanation. if (!has_dependency_) { tmp_to_clear_.push_back(entry.var); - tmp_var_to_trail_index_in_queue_[entry.var] = - std::numeric_limits::max(); + tmp_var_to_trail_index_in_queue_[entry.var] = trail_index; + } + } + + // Update var_to_trail_index_at_lower_level_. + if (info_is_valid_on_subsequent_last_level_expansion_) { + for (const int trail_index : tmp_seen_) { + if (trail_index == 0) continue; + const TrailEntry& entry = integer_trail_[trail_index]; + const int old = var_to_trail_index_at_lower_level_[entry.var]; + if (old == 0) { + to_clear_for_lower_level_.push_back(entry.var); + } + var_to_trail_index_at_lower_level_[entry.var] = + std::max(old, trail_index); } } + tmp_seen_.clear(); // clean-up. for (const IntegerVariable var : tmp_to_clear_) { tmp_var_to_trail_index_in_queue_[var] = 0; } + tmp_to_clear_.clear(); + + time_limit_->AdvanceDeterministicTime(work_done * 5e-9); } // TODO(user): If this is called many time on the same variables, it could be // made faster by using some caching mechanism. absl::Span IntegerTrail::Reason(const Trail& trail, - int trail_index) const { - const int index = boolean_trail_index_to_integer_one_[trail_index]; + int trail_index, + int64_t conflict_id) const { std::vector* reason = trail.GetEmptyVectorToStoreReason(trail_index); added_variables_.ClearAndResize(BooleanVariable(trail_->NumVariables())); - ComputeLazyReasonIfNeeded(index); - AppendLiteralsReason(index, reason); + const int reason_index = boolean_trail_index_to_reason_index_[trail_index]; + ComputeLazyReasonIfNeeded(reason_index); + AppendLiteralsReason(reason_index, reason); DCHECK(tmp_queue_.empty()); - for (const int prev_trail_index : Dependencies(index)) { - if (prev_trail_index < 0) break; + for (const int prev_trail_index : Dependencies(reason_index)) { DCHECK_GE(prev_trail_index, var_lbs_.size()); tmp_queue_.push_back(prev_trail_index); } - MergeReasonIntoInternal(reason); + MergeReasonIntoInternal(reason, conflict_id); return *reason; } @@ -2125,7 +2205,7 @@ bool GenericLiteralWatcher::Propagate(Trail* trail) { // Before we propagate, make sure any reversible structure are up to date. // Note that we never do anything expensive more than once per level. - { + if (id_need_reversible_support_[id]) { const int low = id_to_greatest_common_level_since_last_call_[IdType(id)]; const int high = id_to_level_at_last_call_[id]; @@ -2252,10 +2332,13 @@ void GenericLiteralWatcher::Untrail(const Trail& trail, int trail_index) { int GenericLiteralWatcher::Register(PropagatorInterface* propagator) { const int id = watchers_.size(); watchers_.push_back(propagator); + + id_need_reversible_support_.push_back(false); id_to_level_at_last_call_.push_back(0); id_to_greatest_common_level_since_last_call_.GrowByOne(); id_to_reversible_classes_.push_back(std::vector()); id_to_reversible_ints_.push_back(std::vector()); + id_to_watch_indices_.push_back(std::vector()); id_to_priority_.push_back(1); id_to_idempotence_.push_back(true); @@ -2290,10 +2373,12 @@ void GenericLiteralWatcher::AlwaysCallAtLevelZero(int id) { void GenericLiteralWatcher::RegisterReversibleClass(int id, ReversibleInterface* rev) { + id_need_reversible_support_[id] = true; id_to_reversible_classes_[id].push_back(rev); } void GenericLiteralWatcher::RegisterReversibleInt(int id, int* rev) { + id_need_reversible_support_[id] = true; id_to_reversible_ints_[id].push_back(rev); } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 95f323abd6d..3923dcb89d5 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -752,22 +752,18 @@ class LazyReasonInterface { LazyReasonInterface() = default; virtual ~LazyReasonInterface() = default; - // The function is provided with the IntegerLiteral to explain and its index - // in the integer trail. It must fill the two vectors so that literals - // contains any Literal part of the reason and dependencies contains the trail - // index of any IntegerLiteral that is also part of the reason. + // When called, this must fill the two vectors so that literals contains any + // Literal part of the reason and dependencies contains the trail index of any + // IntegerLiteral that is also part of the reason. // - // Remark: sometimes this is called to fill the conflict while the literal to - // explain is propagated. In this case, trail_index will be the current trail - // index, and we cannot assume that there is anything filled yet in - // integer_literal[trail_index]. + // Remark: integer_literal[trail_index] might not exist or has nothing to + // do with what was propagated. // - // TODO(user): Right now this is only used by "linear" propagator, if we need - // more we could replace {id, propagation_slack} by a generic payload so that - // each implementation can cast it to its need. Then the memory will just be - // the max size of this payload data (16 bytes should be fine). + // TODO(user): {id, propagation_slack, var_to_explain, trail_index} is just a + // generic "payload" and we should probably rename it as such so that each + // implementation can store different things. virtual void Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) = 0; }; @@ -784,6 +780,7 @@ class IntegerTrail final : public SatPropagator { encoder_(model->GetOrCreate()), trail_(model->GetOrCreate()), sat_solver_(model->GetOrCreate()), + time_limit_(model->GetOrCreate()), parameters_(*model->GetOrCreate()) { model->GetOrCreate()->AddPropagator(this); } @@ -799,8 +796,8 @@ class IntegerTrail final : public SatPropagator { // correct state before calling any of its functions. bool Propagate(Trail* trail) final; void Untrail(const Trail& trail, int literal_trail_index) final; - absl::Span Reason(const Trail& trail, - int trail_index) const final; + absl::Span Reason(const Trail& trail, int trail_index, + int64_t conflict_id) const final; // Returns the number of created integer variables. // @@ -1024,10 +1021,8 @@ class IntegerTrail final : public SatPropagator { IntegerLiteral i_lit, int id, IntegerValue propagation_slack, LazyReasonInterface* explainer) { const int trail_index = integer_trail_.size(); - if (trail_index >= lazy_reasons_.size()) { - lazy_reasons_.resize(trail_index + 1); - } - lazy_reasons_[trail_index] = {explainer, propagation_slack, id}; + lazy_reasons_.push_back(LazyReasonEntry{explainer, propagation_slack, + i_lit.var, id, trail_index}); return EnqueueInternal(i_lit, true, {}, {}, 0); } @@ -1153,6 +1148,47 @@ class IntegerTrail final : public SatPropagator { debug_checker_ = std::move(checker); } + // This is used by the GreaterThanAtLeastOneOf() lazy reason. + // + // TODO(user): This might better lives together with the propagation code, + // but it does need access to data about the reason/conflict being currently + // computed. Also for speed we do need all the code here in on block. Given + // than we have just a few "lazy integer reason", we might not really want a + // generic code in any case. + void AddAllGreaterThanConstantReason(absl::Span exprs, + IntegerValue target_min, + std::vector* indices) const { + for (const AffineExpression& expr : exprs) { + if (expr.IsConstant()) { + DCHECK_GE(expr.constant, target_min); + continue; + } + DCHECK_NE(expr.var, kNoIntegerVariable); + + // Skip if we already have an explanation for expr >= target_min. Note + // that we already do that while processing the returned indices, so this + // mainly save a FindLowestTrailIndexThatExplainBound() call per skipped + // indices, which can still be costly. + { + const int index = tmp_var_to_trail_index_in_queue_[expr.var]; + if (index == std::numeric_limits::max()) continue; + if (index > 0 && + expr.ValueAt(integer_trail_[index].bound) >= target_min) { + has_dependency_ = true; + continue; + } + } + + // We need to find the index that explain the bound. + // Note that this will skip if the condition is true at level zero. + const int index = + FindLowestTrailIndexThatExplainBound(expr.GreaterOrEqual(target_min)); + if (index >= 0) { + indices->push_back(index); + } + } + } + private: // Used for DHECKs to validate the reason given to the public functions above. // Tests that all Literal are false. Tests that all IntegerLiteral are true. @@ -1202,7 +1238,8 @@ class IntegerTrail final : public SatPropagator { IntegerLiteral i_lit, Literal literal_reason); // Does the work of MergeReasonInto() when queue_ is already initialized. - void MergeReasonIntoInternal(std::vector* output) const; + void MergeReasonIntoInternal(std::vector* output, + int64_t conflict_id) const; // Returns the lowest trail index of a TrailEntry that can be used to explain // the given IntegerLiteral. The literal must be currently true (CHECKed). @@ -1212,25 +1249,28 @@ class IntegerTrail final : public SatPropagator { // This must be called before Dependencies() or AppendLiteralsReason(). // // TODO(user): Not really robust, try to find a better way. - void ComputeLazyReasonIfNeeded(int trail_index) const; + void ComputeLazyReasonIfNeeded(int reason_index) const; // Helper function to return the "dependencies" of a bound assignment. // All the TrailEntry at these indices are part of the reason for this // assignment. // // Important: The returned Span is only valid up to the next call. - absl::Span Dependencies(int trail_index) const; + absl::Span Dependencies(int reason_index) const; // Helper function to append the Literal part of the reason for this bound // assignment. We use added_variables_ to not add the same literal twice. // Note that looking at literal.Variable() is enough since all the literals // of a reason must be false. - void AppendLiteralsReason(int trail_index, + void AppendLiteralsReason(int reason_index, std::vector* output) const; // Returns some debugging info. std::string DebugString(); + // Used internally to return the next conlict number. + int64_t NextConflictId(); + // Information for each integer variable about its current lower bound and // position of the last TrailEntry in the trail referring to this var. util_intops::StrongVector var_lbs_; @@ -1257,9 +1297,8 @@ class IntegerTrail final : public SatPropagator { IntegerVariable var; int32_t prev_trail_index; - // Index in literals_reason_start_/bounds_reason_starts_ If this is -1, then - // this was a propagation with a lazy reason, and the reason can be - // re-created by calling the function lazy_reasons_[trail_index]. + // Index in literals_reason_start_/bounds_reason_starts_ If this is negative + // then it is a lazy reason. int32_t reason_index; }; std::vector integer_trail_; @@ -1267,15 +1306,18 @@ class IntegerTrail final : public SatPropagator { struct LazyReasonEntry { LazyReasonInterface* explainer; IntegerValue propagation_slack; + IntegerVariable var_to_explain; int id; + int trail_index_at_propagation_time; - void Explain(IntegerLiteral literal_to_explain, int trail_index_of_literal, - std::vector* literals, + void Explain(std::vector* literals, std::vector* dependencies) const { - explainer->Explain(id, propagation_slack, literal_to_explain, - trail_index_of_literal, literals, dependencies); + explainer->Explain(id, propagation_slack, var_to_explain, + trail_index_at_propagation_time, literals, + dependencies); } }; + std::vector lazy_reason_decision_levels_; std::vector lazy_reasons_; // Start of each decision levels in integer_trail_. @@ -1283,18 +1325,15 @@ class IntegerTrail final : public SatPropagator { std::vector integer_search_levels_; // Buffer to store the reason of each trail entry. - // Note that bounds_reason_buffer_ is an "union". It initially contains the - // IntegerLiteral, and is lazily replaced by the result of - // FindLowestTrailIndexThatExplainBound() applied to these literals. The - // encoding is a bit hacky, see Dependencies(). std::vector reason_decision_levels_; std::vector literals_reason_starts_; - std::vector bounds_reason_starts_; std::vector literals_reason_buffer_; - // These two vectors are in one to one correspondence. Dependencies() will + // The last two vectors are in one to one correspondence. Dependencies() will // "cache" the result of the conversion from IntegerLiteral to trail indices // in trail_index_reason_buffer_. + std::vector bounds_reason_starts_; + mutable std::vector cached_sizes_; std::vector bounds_reason_buffer_; mutable std::vector trail_index_reason_buffer_; @@ -1326,16 +1365,23 @@ class IntegerTrail final : public SatPropagator { // Temporary data used by SafeEnqueue(); std::vector tmp_cleaned_reason_; - // For EnqueueLiteral(), we store a special TrailEntry to recover the reason - // lazily. This vector indicates the correspondence between a literal that - // was pushed by this class at a given trail index, and the index of its - // TrailEntry in integer_trail_. - std::vector boolean_trail_index_to_integer_one_; + // For EnqueueLiteral(), we store the reason index at its Boolean trail index. + std::vector boolean_trail_index_to_reason_index_; // We need to know if we skipped some propagation in the current branch. // This is reverted as we backtrack over it. int first_level_without_full_propagation_ = -1; + // This is used to detect when MergeReasonIntoInternal() is called multiple + // time while processing the same conflict. It allows to optimize the reason + // and the time taken to compute it. + mutable int64_t last_conflict_id_ = -1; + mutable bool info_is_valid_on_subsequent_last_level_expansion_ = false; + mutable util_intops::StrongVector + var_to_trail_index_at_lower_level_; + mutable std::vector tmp_seen_; + mutable std::vector to_clear_for_lower_level_; + int64_t num_enqueues_ = 0; int64_t num_untrails_ = 0; int64_t num_level_zero_enqueues_ = 0; @@ -1350,6 +1396,7 @@ class IntegerTrail final : public SatPropagator { IntegerEncoder* encoder_; Trail* trail_; SatSolver* sat_solver_; + TimeLimit* time_limit_; const SatParameters& parameters_; // Temporary "hash" to keep track of all the conditional enqueue that were @@ -1587,10 +1634,12 @@ class GenericLiteralWatcher final : public SatPropagator { // Data for each propagator. DEFINE_STRONG_INDEX_TYPE(IdType); + std::vector id_need_reversible_support_; std::vector id_to_level_at_last_call_; RevVector id_to_greatest_common_level_since_last_call_; std::vector> id_to_reversible_classes_; std::vector> id_to_reversible_ints_; + std::vector> id_to_watch_indices_; std::vector id_to_priority_; std::vector id_to_idempotence_; diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 25f8092195a..9d9924f8963 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -229,16 +229,15 @@ LinearConstraintPropagator::ConditionalLb( template void LinearConstraintPropagator::Explain( - int /*id*/, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, int trail_index, - std::vector* literals_reason, + int /*id*/, IntegerValue propagation_slack, IntegerVariable var_to_explain, + int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) { *literals_reason = literal_reason_; trail_indices_reason->clear(); shared_->reason_coeffs.clear(); for (int i = 0; i < size_; ++i) { const IntegerVariable var = vars_[i]; - if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + if (PositiveVariable(var) == PositiveVariable(var_to_explain)) { continue; } const int index = @@ -653,8 +652,7 @@ LinMinPropagator::LinMinPropagator(const std::vector& exprs, integer_trail_(model_->GetOrCreate()) {} void LinMinPropagator::Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, - int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) { const auto& vars = exprs_[id].vars; @@ -665,7 +663,7 @@ void LinMinPropagator::Explain(int id, IntegerValue propagation_slack, const int size = vars.size(); for (int i = 0; i < size; ++i) { const IntegerVariable var = vars[i]; - if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + if (PositiveVariable(var) == PositiveVariable(var_to_explain)) { continue; } const int index = diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 104cfeeb213..20d4619706b 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -102,7 +102,7 @@ class LinearConstraintPropagator : public PropagatorInterface, // For LazyReasonInterface. void Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) final; @@ -252,7 +252,7 @@ class LinMinPropagator : public PropagatorInterface, LazyReasonInterface { // For LazyReasonInterface. void Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) final; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 92dadbd73be..1305be482e6 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -17,9 +17,7 @@ #include #include #include -#include #include -#include #include #include @@ -28,7 +26,6 @@ #include "absl/meta/type_traits.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/sat/clause.h" @@ -172,7 +169,7 @@ IntegerLiteral SplitUsingBestSolutionValueInRepository( std::function FirstUnassignedVarAtItsMinHeuristic( const std::vector& vars, Model* model) { auto* integer_trail = model->GetOrCreate(); - return [/*copy*/ vars = vars, integer_trail]() { + return [/*copy*/ vars, integer_trail]() { for (const IntegerVariable var : vars) { const IntegerLiteral decision = AtMinValue(var, integer_trail); if (decision.IsValid()) return BooleanOrIntegerLiteral(decision); @@ -257,7 +254,7 @@ std::function LpPseudoCostHeuristic(Model* model) { if (info.score > best_score) { best_score = info.score; - // This direction works better than the inverse in the benchmarks. But + // This direction works better than the inverse in the benchs. But // always branching up seems even better. TODO(user): investigate. if (info.down_score > info.up_score) { decision = BooleanOrIntegerLiteral(info.down_branch); @@ -1025,8 +1022,8 @@ std::function RandomizeOnRestartHeuristic( // Special case: Don't change the decision value. value_selection_weight.push_back(10); - // TODO(user): These distribution values are just guessed values. - // They need to be tuned. + // TODO(user): These distribution values are just guessed values. They + // need to be tuned. std::discrete_distribution val_dist(value_selection_weight.begin(), value_selection_weight.end()); @@ -1117,9 +1114,9 @@ std::function FollowHint( auto* integer_trail = model->GetOrCreate(); auto* rev_int_repo = model->GetOrCreate(); - // This is not ideal as we reserve an int for the full duration of the - // model even if we use this FollowHint() function just for a while. But - // it is an easy solution to not have reference to deleted memory in the + // This is not ideal as we reserve an int for the full duration of the model + // even if we use this FollowHint() function just for a while. But it is + // an easy solution to not have reference to deleted memory in the // RevIntRepository(). Note that once we backtrack, these reference will // disappear. int* rev_start_index = model->TakeOwnership(new int); @@ -1459,15 +1456,15 @@ SatSolver::Status IntegerSearchHelper::SolveIntegerProblem() { CHECK_EQ(num_policies, heuristics.restart_policies.size()); // Note that it is important to do the level-zero propagation if it wasn't - // already done because EnqueueDecisionAndBackjumpOnConflict() assumes - // that the solver is in a "propagated" state. + // already done because EnqueueDecisionAndBackjumpOnConflict() assumes that + // the solver is in a "propagated" state. // - // TODO(user): We have the issue that at level zero. calling the - // propagation loop more than once can propagate more! This is because we - // call the LP again and again on each level zero propagation. This is - // causing some CHECKs() to fail in multithread (rarely) because when we - // associate new literals to integer ones, Propagate() is indirectly - // called. Not sure yet how to fix. + // TODO(user): We have the issue that at level zero. calling the propagation + // loop more than once can propagate more! This is because we call the LP + // again and again on each level zero propagation. This is causing some + // CHECKs() to fail in multithread (rarely) because when we associate new + // literals to integer ones, Propagate() is indirectly called. Not sure yet + // how to fix. if (!sat_solver_->FinishPropagation()) return sat_solver_->UnsatStatus(); // Main search loop. @@ -1511,21 +1508,20 @@ SatSolver::Status IntegerSearchHelper::SolveIntegerProblem() { // No decision means that we reached a leave of the search tree and that // we have a feasible solution. // - // Tricky: If the time limit is reached during the final propagation - // when all variables are fixed, there is no guarantee that the - // propagation responsible for testing the validity of the solution was - // run to completion. So we cannot report a feasible solution. + // Tricky: If the time limit is reached during the final propagation when + // all variables are fixed, there is no guarantee that the propagation + // responsible for testing the validity of the solution was run to + // completion. So we cannot report a feasible solution. if (time_limit_->LimitReached()) return SatSolver::LIMIT_REACHED; if (decision == kNoLiteralIndex) { - // Save the current polarity of all Booleans in the solution. It will - // be followed for the next SAT decisions. This is known to be a good - // policy for optimization problem. Note that for decision problem we - // don't care since we are just done as soon as a solution is found. + // Save the current polarity of all Booleans in the solution. It will be + // followed for the next SAT decisions. This is known to be a good policy + // for optimization problem. Note that for decision problem we don't care + // since we are just done as soon as a solution is found. // // This idea is kind of "well known", see for instance the "LinSBPS" - // submission to the maxSAT 2018 competition by Emir Demirovic and - // Peter Stuckey where they show it is a good idea and provide more - // references. + // submission to the maxSAT 2018 competition by Emir Demirovic and Peter + // Stuckey where they show it is a good idea and provide more references. if (parameters_.use_optimization_hints()) { auto* sat_decision = model_->GetOrCreate(); const auto& trail = *model_->GetOrCreate(); @@ -1540,9 +1536,9 @@ SatSolver::Status IntegerSearchHelper::SolveIntegerProblem() { return sat_solver_->UnsatStatus(); } - // In multi-thread, we really only want to save the LP relaxation for - // thread with high linearization level to avoid to pollute the - // repository with sub-par lp solutions. + // In multi-thread, we really only want to save the LP relaxation for thread + // with high linearization level to avoid to pollute the repository with + // sub-par lp solutions. // // TODO(user): Experiment more around dynamically changing the // threshold for storing LP solutions in the pool. Alternatively expose @@ -1556,10 +1552,9 @@ SatSolver::Status IntegerSearchHelper::SolveIntegerProblem() { parameters_.linearization_level() >= 2) { num_decisions_since_last_lp_record_++; if (num_decisions_since_last_lp_record_ >= 100) { - // NOTE: We can actually record LP solutions more frequently. - // However this process is time consuming and workers waste a lot of - // time doing this. To avoid this we don't record solutions after - // each decision. + // NOTE: We can actually record LP solutions more frequently. However + // this process is time consuming and workers waste a lot of time doing + // this. To avoid this we don't record solutions after each decision. RecordLPRelaxationValues(model_); num_decisions_since_last_lp_record_ = 0; } diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index 51e608ef1a1..5a3d5d64b4f 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -29,6 +29,7 @@ #include #include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/implied_bounds.h" diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index c61610b7271..984559e1d37 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -229,6 +229,7 @@ SchedulingConstraintHelper::SchedulingConstraintHelper( const std::vector& tasks, Model* model) : model_(model), trail_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), watcher_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), @@ -273,6 +274,7 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, Model* model) : model_(model), trail_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), capacity_(num_tasks), @@ -300,19 +302,6 @@ bool SchedulingConstraintHelper::IncrementalPropagate( return true; } -void SchedulingConstraintHelper::SetLevel(int level) { - // If there was an Untrail before, we need to refresh the cache so that - // we never have value from lower in the search tree. - // - // TODO(user): We could be smarter here, but then this is not visible in our - // cpu_profile since we call many times IncrementalPropagate() for each new - // decision, but just call Propagate() once after each Untrail(). - if (level < previous_level_) { - recompute_all_cache_ = true; - } - previous_level_ = level; -} - void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); const int num_tasks = starts_.size(); @@ -322,10 +311,6 @@ void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { watcher->WatchIntegerVariable(ends_[t].var, id, t); } watcher->SetPropagatorPriority(id, 0); - - // Note that it is important to register with the integer_trail_ so we are - // ALWAYS called before any propagator that depends on this helper. - integer_trail_->RegisterReversibleClass(this); } bool SchedulingConstraintHelper::UpdateCachedValues(int t) { @@ -498,6 +483,14 @@ void SchedulingConstraintHelper::SetTimeDirection(bool is_forward) { bool SchedulingConstraintHelper::SynchronizeAndSetTimeDirection( bool is_forward) { SetTimeDirection(is_forward); + + // If there was any backtracks since the last time this was called, we + // recompute our cache. + if (sat_solver_->num_backtracks() != saved_num_backtracks_) { + recompute_all_cache_ = true; + saved_num_backtracks_ = sat_solver_->num_backtracks(); + } + if (recompute_all_cache_) { for (int t = 0; t < recompute_cache_.size(); ++t) { if (!UpdateCachedValues(t)) return false; diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 31869bbe8cb..65157b56e1a 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -261,8 +261,7 @@ struct CachedTaskBounds { // One of the main advantage of this class is that it allows to share the // vectors of tasks sorted by various criteria between propagator for a faster // code. -class SchedulingConstraintHelper : public PropagatorInterface, - ReversibleInterface { +class SchedulingConstraintHelper : public PropagatorInterface { public: // All the functions below refer to a task by its index t in the tasks // vector given at construction. @@ -284,7 +283,6 @@ class SchedulingConstraintHelper : public PropagatorInterface, bool Propagate() final; bool IncrementalPropagate(const std::vector& watch_indices) final; void RegisterWith(GenericLiteralWatcher* watcher); - void SetLevel(int level) final; // Resets the class to the same state as if it was constructed with // the given subset of tasks from other. @@ -545,6 +543,7 @@ class SchedulingConstraintHelper : public PropagatorInterface, Model* model_; Trail* trail_; + SatSolver* sat_solver_; IntegerTrail* integer_trail_; GenericLiteralWatcher* watcher_; PrecedenceRelations* precedence_relations_; @@ -565,8 +564,8 @@ class SchedulingConstraintHelper : public PropagatorInterface, std::vector minus_starts_; std::vector minus_ends_; - // This is used by SetLevel() to detect untrail. - int previous_level_ = 0; + // This is used to detect when we need to invalidate the cache. + int64_t saved_num_backtracks_ = 0; // The caches of all relevant interval values. // These are initially of size capacity and never resized. diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index db1b60f13df..99299211b48 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -886,8 +886,7 @@ bool LinearPropagator::PropagateInfeasibleConstraint(int id, } void LinearPropagator::Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, - int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) { literals_reason->clear(); @@ -900,7 +899,7 @@ void LinearPropagator::Explain(int id, IntegerValue propagation_slack, const auto vars = GetVariables(info); for (int i = 0; i < info.initial_size; ++i) { const IntegerVariable var = vars[i]; - if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + if (PositiveVariable(var) == PositiveVariable(var_to_explain)) { continue; } const int index = diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index 20e470935d6..ff8927ffd1f 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -317,7 +317,7 @@ class LinearPropagator : public PropagatorInterface, // For LazyReasonInterface. void Explain(int id, IntegerValue propagation_slack, - IntegerLiteral literal_to_explain, int trail_index, + IntegerVariable var_to_explain, int trail_index, std::vector* literals_reason, std::vector* trail_indices_reason) final; diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index 0595ec66dce..847d64cf082 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -856,7 +856,7 @@ ConstraintProto* ConstraintScaler::AddConstraint( } // TODO(user): unit test this. -double FindFractionalScaling(const std::vector& coefficients, +double FindFractionalScaling(absl::Span coefficients, double tolerance) { double multiplier = 1.0; for (const double coeff : coefficients) { diff --git a/ortools/sat/pb_constraint.cc b/ortools/sat/pb_constraint.cc index c60e7438635..151fa3d814d 100644 --- a/ortools/sat/pb_constraint.cc +++ b/ortools/sat/pb_constraint.cc @@ -14,6 +14,7 @@ #include "ortools/sat/pb_constraint.h" #include +#include #include #include #include @@ -976,7 +977,8 @@ void PbConstraints::Untrail(const Trail& trail, int trail_index) { } absl::Span PbConstraints::Reason(const Trail& trail, - int trail_index) const { + int trail_index, + int64_t /*conflict_id*/) const { SCOPED_TIME_STAT(&stats_); const PbConstraintsEnqueueHelper::ReasonInfo& reason_info = enqueue_helper_.reasons[trail_index]; diff --git a/ortools/sat/pb_constraint.h b/ortools/sat/pb_constraint.h index acd164a377d..fb093e8f66c 100644 --- a/ortools/sat/pb_constraint.h +++ b/ortools/sat/pb_constraint.h @@ -553,8 +553,8 @@ class PbConstraints : public SatPropagator { bool Propagate(Trail* trail) final; void Untrail(const Trail& trail, int trail_index) final; - absl::Span Reason(const Trail& trail, - int trail_index) const final; + absl::Span Reason(const Trail& trail, int trail_index, + int64_t conflict_id) const final; // Changes the number of variables. void Resize(int num_variables) { diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index 1a19cab9e3f..e11c15b1f70 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -406,7 +406,7 @@ bool Prober::ProbeDnf(absl::string_view name, num_new_literals_fixed_ > previous_num_literals_fixed) { VLOG(1) << "ProbeDnf(" << name << ", num_fixed_literals=" << num_new_literals_fixed_ - previous_num_literals_fixed - << ", num_fixed_integer_bounds=" + << ", num_pushed_integer_bounds=" << num_new_integer_bounds_ - previous_num_integer_bounds << ", num_valid_conjunctions=" << num_valid_conjunctions << "/" << dnf.size() << ")"; diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 0f0d769735f..c342af7105f 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -365,7 +365,17 @@ class Trail { // Note that this shouldn't be called on a variable at level zero, because we // don't cleanup the reason data for these variables but the underlying // clauses may have been deleted. - absl::Span Reason(BooleanVariable var) const; + // + // If conflict_id >= 0, this indicate that this was called as part of the + // first-UIP procedure. It has a few implication: + // - The reason do not need to be cached and can be adapted to the current + // conflict. + // - Some data can be reused between two calls about the same conflict. + // - Note however that if the reason is a simple clause, we shouldn't adapt + // it because we rely on extra fact in the first UIP code where we detect + // subsumed clauses for instance. + absl::Span Reason(BooleanVariable var, + int64_t conflict_id = -1) const; // Returns the "type" of an assignment (see AssignmentType). Note that this // function never returns kSameReasonAs or kCachedReason, it instead returns @@ -568,8 +578,13 @@ class SatPropagator { // The returned Span has to be valid until the literal is untrailed. A client // can use trail_.GetEmptyVectorToStoreReason() if it doesn't have a memory // location that already contains the reason. + // + // If conlict id is positive, then this is called during first UIP resolution + // and we will backtrack over this literal right away, so we don't need to + // have a span that survive more than once. virtual absl::Span Reason(const Trail& /*trail*/, - int /*trail_index*/) const { + int /*trail_index*/, + int64_t /*conflict_id*/) const { LOG(FATAL) << "Not implemented."; return {}; } @@ -662,7 +677,8 @@ inline int Trail::AssignmentType(BooleanVariable var) const { return type != AssignmentType::kCachedReason ? type : old_type_[var]; } -inline absl::Span Trail::Reason(BooleanVariable var) const { +inline absl::Span Trail::Reason(BooleanVariable var, + int64_t conflict_id) const { // Special case for AssignmentType::kSameReasonAs to avoid a recursive call. var = ReferenceVarWithSameReason(var); @@ -684,7 +700,8 @@ inline absl::Span Trail::Reason(BooleanVariable var) const { } else { DCHECK_LT(info.type, propagators_.size()); DCHECK(propagators_[info.type] != nullptr) << info.type; - reasons_[var] = propagators_[info.type]->Reason(*this, info.trail_index); + reasons_[var] = + propagators_[info.type]->Reason(*this, info.trail_index, conflict_id); } old_type_[var] = info.type; info_[var].type = AssignmentType::kCachedReason; diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 36c49a210eb..4babf78c7cd 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -108,6 +108,8 @@ int64_t SatSolver::num_propagations() const { return trail_->NumberOfEnqueues() - counters_.num_branches; } +int64_t SatSolver::num_backtracks() const { return counters_.num_backtracks; } + int64_t SatSolver::num_restarts() const { return counters_.num_restarts; } double SatSolver::deterministic_time() const { @@ -1044,6 +1046,7 @@ void SatSolver::Backtrack(int target_level) { DCHECK_LE(target_level, CurrentDecisionLevel()); // Any backtrack to the root from a positive one is counted as a restart. + counters_.num_backtracks++; if (target_level == 0) counters_.num_restarts++; // Per the SatPropagator interface, this is needed before calling Untrail. @@ -2024,6 +2027,7 @@ void SatSolver::ComputeFirstUIPConflict( std::vector* reason_used_to_infer_the_conflict, std::vector* subsumed_clauses) { SCOPED_TIME_STAT(&stats_); + const int64_t conflict_id = counters_.num_failures; // This will be used to mark all the literals inspected while we process the // conflict and the reasons behind each of its variable assignments. @@ -2131,7 +2135,7 @@ void SatSolver::ComputeFirstUIPConflict( literal.Variable()) != literal.Variable()) { clause_to_expand = {}; } else { - clause_to_expand = trail_->Reason(literal.Variable()); + clause_to_expand = trail_->Reason(literal.Variable(), conflict_id); } sat_clause = ReasonClauseOrNull(literal.Variable()); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index cb21c724e38..b8d2deaaff5 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -391,6 +391,7 @@ class SatSolver { int64_t num_branches() const; int64_t num_failures() const; int64_t num_propagations() const; + int64_t num_backtracks() const; // Note that we count the number of backtrack to level zero from a positive // level. Those can corresponds to actual restarts, or conflicts that learn @@ -403,6 +404,7 @@ class SatSolver { int64_t num_branches = 0; int64_t num_failures = 0; int64_t num_restarts = 0; + int64_t num_backtracks = 0; // Minimization stats. int64_t num_minimizations = 0; diff --git a/ortools/sat/scheduling_cuts.cc b/ortools/sat/scheduling_cuts.cc index dfb7b8aa868..ff1c03d7335 100644 --- a/ortools/sat/scheduling_cuts.cc +++ b/ortools/sat/scheduling_cuts.cc @@ -847,7 +847,7 @@ struct CachedIntervalData { }; void GenerateCutsBetweenPairOfNonOverlappingTasks( - const std::string& cut_name, + absl::string_view cut_name, const util_intops::StrongVector& lp_values, std::vector events, IntegerValue capacity_max, Model* model, LinearConstraintManager* manager) { @@ -1268,7 +1268,7 @@ void GenerateShortCompletionTimeCutsWithExactBound( // after a given start_min, sorted by relative (end_lp - start_min). // // TODO(user): merge with Packing cuts. -void GenerateCompletionTimeCutsWithEnergy(const std::string& cut_name, +void GenerateCompletionTimeCutsWithEnergy(absl::string_view cut_name, std::vector events, IntegerValue capacity_max, bool skip_low_sizes, Model* model, @@ -1407,7 +1407,7 @@ void GenerateCompletionTimeCutsWithEnergy(const std::string& cut_name, add_energy_to_name |= event.use_energy; cut.AddTerm(event.x_end, event.energy_min * best_capacity); } - std::string full_name = cut_name; + std::string full_name(cut_name); if (is_lifted) full_name.append("_lifted"); if (add_energy_to_name) full_name.append("_energy"); if (best_capacity < capacity_max) { diff --git a/ortools/sat/symmetry.cc b/ortools/sat/symmetry.cc index f302e92fc0f..e0acb7e3822 100644 --- a/ortools/sat/symmetry.cc +++ b/ortools/sat/symmetry.cc @@ -13,6 +13,7 @@ #include "ortools/sat/symmetry.h" +#include #include #include @@ -152,8 +153,8 @@ void SymmetryPropagator::Untrail(const Trail& trail, int trail_index) { } } -absl::Span SymmetryPropagator::Reason(const Trail& trail, - int trail_index) const { +absl::Span SymmetryPropagator::Reason( + const Trail& trail, int trail_index, int64_t /*conflict_id*/) const { SCOPED_TIME_STAT(&stats_); const ReasonInfo& reason_info = reasons_[trail_index]; std::vector* reason = trail.GetEmptyVectorToStoreReason(trail_index); diff --git a/ortools/sat/symmetry.h b/ortools/sat/symmetry.h index f7fff5ebe40..3e94bfd639e 100644 --- a/ortools/sat/symmetry.h +++ b/ortools/sat/symmetry.h @@ -69,8 +69,8 @@ class SymmetryPropagator : public SatPropagator { bool Propagate(Trail* trail) final; void Untrail(const Trail& trail, int trail_index) final; - absl::Span Reason(const Trail& trail, - int trail_index) const final; + absl::Span Reason(const Trail& trail, int trail_index, + int64_t conflict_id) const final; // Adds a new permutation to this symmetry propagator. The ownership is // transferred. This must be an integer permutation such that: diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 78d506ca659..edbc813bb1b 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -24,6 +24,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" @@ -425,25 +426,29 @@ double Percentile::GetPercentile(double percent) { CHECK_GT(records_.size(), 0); CHECK_LE(percent, 100.0); CHECK_GE(percent, 0.0); - std::vector sorted_records(records_.begin(), records_.end()); - std::sort(sorted_records.begin(), sorted_records.end()); - const int num_records = sorted_records.size(); + const int num_records = records_.size(); const double percentile_rank = static_cast(num_records) * percent / 100.0 - 0.5; if (percentile_rank <= 0) { - return sorted_records.front(); + return *absl::c_min_element(records_); } else if (percentile_rank >= num_records - 1) { - return sorted_records.back(); + return *absl::c_max_element(records_); } + std::vector sorted_records; + sorted_records.assign(records_.begin(), records_.end()); // Interpolate. DCHECK_GE(num_records, 2); DCHECK_LT(percentile_rank, num_records - 1); const int lower_rank = static_cast(std::floor(percentile_rank)); DCHECK_LT(lower_rank, num_records - 1); - return sorted_records[lower_rank] + - (percentile_rank - lower_rank) * - (sorted_records[lower_rank + 1] - sorted_records[lower_rank]); + auto upper_it = sorted_records.begin() + lower_rank + 1; + // Ensure that sorted_records[lower_rank + 1] is in the correct place as if + // records were actually sorted, the next closest is then the largest element + // to the left of this element. + absl::c_nth_element(sorted_records, upper_it); + auto lower_it = std::max_element(sorted_records.begin(), upper_it); + return *lower_it + (percentile_rank - lower_rank) * (*upper_it - *lower_it); } void CompressTuples(absl::Span domain_sizes, diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 9fe0a1548f7..9ccdf7d3095 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -25,6 +25,7 @@ #include #include +#include "absl/base/macros.h" #include "absl/container/btree_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" @@ -36,7 +37,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/logging.h" -#include "ortools/base/mathlimits.h" +#include "ortools/base/mathutil.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -582,7 +583,7 @@ class Percentile { // Returns number of stored records. int64_t NumRecords() const { return records_.size(); } - // Note that this is not fast and runs in O(n log n) for n records. + // Note that this runs in O(n) for n records. double GetPercentile(double percent); private: