Skip to content

Commit

Permalink
refactor state
Browse files Browse the repository at this point in the history
  • Loading branch information
drexlerd committed Nov 12, 2024
1 parent 9f44726 commit 7cc3d71
Show file tree
Hide file tree
Showing 30 changed files with 166 additions and 193 deletions.
4 changes: 2 additions & 2 deletions include/mimir/search/action.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class PDDLFactories;
* Flatmemory types
*/

/// @brief `FlatSimpleEffect` encapsulates the effect on a single grounded atom.
/// We cannot consistently use cista::tuple since nested tuples will automatically be flattened.
struct FlatSimpleEffect
{
bool is_negated;
uint32_t atom_index;

bool operator==(const FlatSimpleEffect& other) const;
};

using FlatStripsActionPrecondition = cista::tuple<FlatBitset, // positive static atoms
Expand Down
4 changes: 2 additions & 2 deletions include/mimir/search/algorithms/siw/goal_strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class ProblemGoalCounter : public IGoalStrategy

int m_initial_num_unsatisfied_goals;

int count_unsatisfied_goals(const State state) const;
int count_unsatisfied_goals(State state) const;

public:
explicit ProblemGoalCounter(Problem problem, State state);

bool test_static_goal() override;
bool test_dynamic_goal(const State state) override;
bool test_dynamic_goal(State state) override;
};
}

Expand Down
4 changes: 2 additions & 2 deletions include/mimir/search/algorithms/strategies/goal_strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class IGoalStrategy
virtual ~IGoalStrategy() = default;

virtual bool test_static_goal() = 0;
virtual bool test_dynamic_goal(const State state) = 0;
virtual bool test_dynamic_goal(State state) = 0;
};

class ProblemGoal : public IGoalStrategy
Expand All @@ -45,7 +45,7 @@ class ProblemGoal : public IGoalStrategy
explicit ProblemGoal(Problem problem);

bool test_static_goal() override;
bool test_dynamic_goal(const State state) override;
bool test_dynamic_goal(State state) override;
};
}

Expand Down
12 changes: 6 additions & 6 deletions include/mimir/search/algorithms/strategies/pruning_strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ class IPruningStrategy
public:
virtual ~IPruningStrategy() = default;

virtual bool test_prune_initial_state(const State state) = 0;
virtual bool test_prune_successor_state(const State state, const State succ_state, bool is_new_succ) = 0;
virtual bool test_prune_initial_state(State state) = 0;
virtual bool test_prune_successor_state(State state, State succ_state, bool is_new_succ) = 0;
};

class NoStatePruning : public IPruningStrategy
{
public:
bool test_prune_initial_state(const State state) override;
bool test_prune_successor_state(const State state, const State succ_state, bool is_new_succ) override;
bool test_prune_initial_state(State state) override;
bool test_prune_successor_state(State state, State succ_state, bool is_new_succ) override;
};

class DuplicateStatePruning : public IPruningStrategy
{
public:
bool test_prune_initial_state(const State state) override;
bool test_prune_successor_state(const State state, const State succ_state, bool is_new_succ) override;
bool test_prune_initial_state(State state) override;
bool test_prune_successor_state(State state, State succ_state, bool is_new_succ) override;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class GroundedApplicableActionGenerator : public IApplicableActionGenerator

void generate_applicable_actions(State state, GroundActionList& out_applicable_actions) override;

void generate_and_apply_axioms(StateBuilder& unextended_state) override;
void generate_and_apply_axioms(StateImpl& unextended_state) override;

void on_finish_search_layer() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class IApplicableActionGenerator
virtual void generate_applicable_actions(State state, GroundActionList& out_applicable_actions) = 0;

/// @brief Generate all applicable axioms for a given set of ground atoms by running fixed point computation.
virtual void generate_and_apply_axioms(StateBuilder& unextended_state) = 0;
virtual void generate_and_apply_axioms(StateImpl& unextended_state) = 0;

// Notify that a new f-layer was reached
virtual void on_finish_search_layer() const = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class LiftedApplicableActionGenerator : public IApplicableActionGenerator

void generate_applicable_actions(State state, GroundActionList& out_applicable_actions) override;

void generate_and_apply_axioms(StateBuilder& unextended_state) override;
void generate_and_apply_axioms(StateImpl& unextended_state) override;

void on_finish_search_layer() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class AxiomEvaluator
AxiomEvaluator& operator=(AxiomEvaluator&& other) = delete;

/// @brief Generate and apply all applicable axioms.
void generate_and_apply_axioms(StateBuilder& unextended_state);
void generate_and_apply_axioms(StateImpl& unextended_state);

/// @brief Return the axiom partitioning.
const std::vector<AxiomPartition>& get_axiom_partitioning() const;
Expand Down
3 changes: 2 additions & 1 deletion include/mimir/search/declarations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class GroundAction;
class GroundAxiom;

// State
class State;
struct StateImpl;
using State = const StateImpl*;

/* ApplicableActionGenerators */
class IApplicableActionGenerator;
Expand Down
58 changes: 23 additions & 35 deletions include/mimir/search/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,15 @@

namespace mimir
{
using FlatState = cista::tuple<Index, FlatBitset, FlatBitset>;

/// @brief `StateBuilder` encapsulates mutable data of a state.
class StateBuilder
/// @brief `StateImpl` encapsulates the fluent and derived atoms of a planning state.
/// We refer to the fluent atoms as the non-extended state
/// and the fluent and derived atoms as the extended state.
struct StateImpl
{
public:
explicit StateBuilder();

Index& get_index();

template<DynamicPredicateCategory P>
FlatBitset& get_atoms();

FlatState& get_data();
const FlatState& get_data() const;

private:
FlatState m_data;
};

/// @brief `State` is an immutable view on the data of a state.
class State
{
public:
explicit State(const FlatState& data);

Index get_index() const;
Index m_index;
FlatBitset m_fluent_atoms;
FlatBitset m_derived_atoms;

template<DynamicPredicateCategory P>
bool contains(GroundAtom<P> atom) const;
Expand All @@ -73,31 +55,37 @@ class State
template<DynamicPredicateCategory P>
bool literals_hold(const GroundLiteralList<P>& literals) const;

/* Getters */

Index& get_index();

Index get_index() const;

template<DynamicPredicateCategory P>
const FlatBitset& get_atoms() const;
FlatBitset& get_atoms();

private:
std::reference_wrapper<const FlatState> m_data;
template<DynamicPredicateCategory P>
const FlatBitset& get_atoms() const;
};

// Compare the state index, since states returned by the `StateRepository` are already unique by their index.
extern bool operator==(State lhs, State rhs);
extern bool operator!=(State lhs, State rhs);
extern bool operator==(const StateImpl& lhs, const StateImpl& rhs);
extern bool operator!=(const StateImpl& lhs, const StateImpl& rhs);

}

// Only hash/compare the non-extended portion of a state, and the problem.
// The extended portion is always equal for the same non-extended portion.
// We use it for the unique state construction in the `StateRepository`.
template<>
struct cista::storage::DerefStdHasher<mimir::FlatState>
struct cista::storage::DerefStdHasher<mimir::StateImpl>
{
size_t operator()(const mimir::FlatState* ptr) const;
size_t operator()(const mimir::StateImpl* ptr) const;
};
template<>
struct cista::storage::DerefStdEqualTo<mimir::FlatState>
struct cista::storage::DerefStdEqualTo<mimir::StateImpl>
{
bool operator()(const mimir::FlatState* lhs, const mimir::FlatState* rhs) const;
bool operator()(const mimir::StateImpl* lhs, const mimir::StateImpl* rhs) const;
};

// Hash the state index, since states returned by the `StateRepository` are already unique by their index.
Expand All @@ -110,7 +98,7 @@ struct std::hash<mimir::State>
namespace mimir
{

using FlatStateSet = cista::storage::UnorderedSet<FlatState>;
using FlatStateSet = cista::storage::UnorderedSet<StateImpl>;

using StateList = std::vector<State>;

Expand Down
2 changes: 1 addition & 1 deletion include/mimir/search/state_repository.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class StateRepository
bool m_problem_or_domain_has_axioms;

FlatStateSet m_states;
StateBuilder m_state_builder;
StateImpl m_state_builder;

FlatBitset m_reached_fluent_atoms;
FlatBitset m_reached_derived_atoms;
Expand Down
6 changes: 3 additions & 3 deletions src/datasets/faithful_abstraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ std::optional<FaithfulAbstraction> FaithfulAbstraction::create(Problem problem,
*factories,
problem,
initial_state,
initial_state.get_index(),
initial_state->get_index(),
options.mark_true_goal_literals,
*object_graph_pruning_strategy);
// std::cout << problem->get_filepath().value() << std::endl;
Expand All @@ -171,7 +171,7 @@ std::optional<FaithfulAbstraction> FaithfulAbstraction::create(Problem problem,

lifo_queue.pop_back();

if (state.literals_hold(problem->get_goal_condition<Fluent>()) && state.literals_hold(problem->get_goal_condition<Derived>()))
if (state->literals_hold(problem->get_goal_condition<Fluent>()) && state->literals_hold(problem->get_goal_condition<Derived>()))
{
abstract_goal_states.insert(abstract_state_index);
}
Expand All @@ -196,7 +196,7 @@ std::optional<FaithfulAbstraction> FaithfulAbstraction::create(Problem problem,
*factories,
problem,
successor_state,
successor_state.get_index(),
successor_state->get_index(),
options.mark_true_goal_literals,
*object_graph_pruning_strategy);

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/state_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ std::optional<StateSpace> StateSpace::create(Problem problem,
const auto vertex = lifo_queue.back();
const auto vertex_index = vertex.get_index();
lifo_queue.pop_back();
if (mimir::get_state(vertex).literals_hold(problem->get_goal_condition<Fluent>())
&& mimir::get_state(vertex).literals_hold(problem->get_goal_condition<Derived>()))
if (mimir::get_state(vertex)->literals_hold(problem->get_goal_condition<Fluent>())
&& mimir::get_state(vertex)->literals_hold(problem->get_goal_condition<Derived>()))
{
goal_vertex_indices.insert(vertex_index);
}
Expand Down
2 changes: 1 addition & 1 deletion src/graphs/color_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ template Color ProblemColorFunction::get_color(GroundAtom<Derived> atom, size_t
template<DynamicPredicateCategory P>
Color ProblemColorFunction::get_color(State state, GroundLiteral<P> literal, size_t pos, bool mark_true_goal_literal) const
{
bool is_satisfied_in_goal = state.literal_holds(literal);
bool is_satisfied_in_goal = state->literal_holds(literal);
return m_name_to_color.at(literal->get_atom()->get_predicate()->get_name() + ":g"
+ (mark_true_goal_literal ? (is_satisfied_in_goal ? ":true" : ":false") : "") + ":" + std::to_string(pos));
}
Expand Down
4 changes: 2 additions & 2 deletions src/graphs/object_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ static void add_ground_atoms_graph_structures(const ProblemColorFunction& color_
add_ground_atom_graph_structures(color_function, object_to_vertex_index, atom, out_digraph);
}
}
for (const auto& atom : pddl_factories.get_ground_atoms_from_indices<Fluent>(state.get_atoms<Fluent>()))
for (const auto& atom : pddl_factories.get_ground_atoms_from_indices<Fluent>(state->get_atoms<Fluent>()))
{
if (!pruning_strategy.prune(state_index, atom))
{
add_ground_atom_graph_structures(color_function, object_to_vertex_index, atom, out_digraph);
}
}
for (const auto& atom : pddl_factories.get_ground_atoms_from_indices<Derived>(state.get_atoms<Derived>()))
for (const auto& atom : pddl_factories.get_ground_atoms_from_indices<Derived>(state->get_atoms<Derived>()))
{
if (!pruning_strategy.prune(state_index, atom))
{
Expand Down
12 changes: 6 additions & 6 deletions src/graphs/object_graph_pruning_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ std::optional<ObjectGraphStaticSccPruningStrategy> ObjectGraphStaticSccPruningSt
auto group = partitioning.at(group_index);

// Reuse memory.
always_true_fluent_atoms = get_state(state_space->get_graph().get_vertices().at(group.front().second)).get_atoms<Fluent>();
always_true_derived_atoms = get_state(state_space->get_graph().get_vertices().at(group.front().second)).get_atoms<Derived>();
always_true_fluent_atoms = get_state(state_space->get_graph().get_vertices().at(group.front().second))->get_atoms<Fluent>();
always_true_derived_atoms = get_state(state_space->get_graph().get_vertices().at(group.front().second))->get_atoms<Derived>();
always_false_fluent_atoms.unset_all();
always_false_derived_atoms.unset_all();

Expand All @@ -260,10 +260,10 @@ std::optional<ObjectGraphStaticSccPruningStrategy> ObjectGraphStaticSccPruningSt
for (const auto& [group_index, state_index] : group)
{
const auto& state = get_state(state_space->get_graph().get_vertices().at(state_index));
always_true_fluent_atoms &= state.get_atoms<Fluent>();
always_true_derived_atoms &= state.get_atoms<Derived>();
always_false_fluent_atoms -= state.get_atoms<Fluent>();
always_false_derived_atoms -= state.get_atoms<Derived>();
always_true_fluent_atoms &= state->get_atoms<Fluent>();
always_true_derived_atoms &= state->get_atoms<Derived>();
always_false_fluent_atoms -= state->get_atoms<Fluent>();
always_false_derived_atoms -= state->get_atoms<Derived>();
}

/* 2. Initialize prunable objects to all objects.
Expand Down
2 changes: 1 addition & 1 deletion src/graphs/tuple_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::optional<TupleVertexIndexList> TupleGraph::compute_admissible_chain(const G
{
for (const auto& state : group)
{
if (state.get_atoms<Fluent>().is_superseteq(fluent_atom_bitset))
if (state->get_atoms<Fluent>().is_superseteq(fluent_atom_bitset))
{
states.push_back(state);
}
Expand Down
15 changes: 2 additions & 13 deletions src/search/action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,6 @@ bool cista::storage::DerefStdEqualTo<mimir::FlatAction>::operator()(const mimir:
namespace mimir
{

/* FlatSimpleEffect*/

bool FlatSimpleEffect::operator==(const FlatSimpleEffect& other) const
{
if (this != &other)
{
return is_negated == other.is_negated && atom_index == other.atom_index;
}
return true;
}

/* StripsActionPreconditionBuilder */
StripsActionPreconditionBuilder::StripsActionPreconditionBuilder(FlatStripsActionPrecondition& builder) : m_builder(builder) {}

Expand Down Expand Up @@ -176,7 +165,7 @@ template bool StripsActionPrecondition::is_applicable<Derived>(const FlatBitset&
template<DynamicPredicateCategory P>
bool StripsActionPrecondition::is_applicable(State state) const
{
return is_applicable<P>(state.get_atoms<P>());
return is_applicable<P>(state->get_atoms<P>());
}

template bool StripsActionPrecondition::is_applicable<Fluent>(State state) const;
Expand Down Expand Up @@ -331,7 +320,7 @@ const FlatSimpleEffect& ConditionalEffect::get_simple_effect() const { return ci
template<DynamicPredicateCategory P>
bool ConditionalEffect::is_applicable(State state) const
{
const auto& state_atoms = state.get_atoms<P>();
const auto& state_atoms = state->get_atoms<P>();

return is_superseteq(state_atoms, get_positive_precondition<P>()) //
&& are_disjoint(state_atoms, get_negative_precondition<P>());
Expand Down
Loading

0 comments on commit 7cc3d71

Please sign in to comment.