Skip to content

Commit

Permalink
Merge pull request #1154 from borglab/decisiontree/nrAssignments
Browse files Browse the repository at this point in the history
Record number of assignments for each leaf
  • Loading branch information
varunagrawal authored Apr 2, 2022
2 parents b2b878e + cc7f499 commit 99c01c4
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
36 changes: 26 additions & 10 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,23 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;

/** Constructor from constant */
Leaf(const Y& constant) :
constant_(constant) {}
/** The number of assignments contained within this leaf
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}

/** return the constant */
const Y& constant() const {
return constant_;
}

/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }

/// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_;
Expand Down Expand Up @@ -108,14 +116,14 @@ namespace gtsam {

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_)));
NodePtr f(new Leaf(op(constant_), nrAssignments_));
return f;
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_)));
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
return f;
}

Expand All @@ -130,7 +138,8 @@ namespace gtsam {

// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
return h;
}

Expand All @@ -141,7 +150,7 @@ namespace gtsam {

/** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant()));
return NodePtr(new Leaf(constant(), nrAssignments()));
}

bool isLeaf() const override { return true; }
Expand Down Expand Up @@ -178,9 +187,16 @@ namespace gtsam {
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());

size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
}
NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
#endif
Expand Down Expand Up @@ -640,7 +656,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
}

// Check if Choice
Expand Down
43 changes: 43 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int);
}

/* ************************************************************************** */
// Test nrAssignments.
TEST(DecisionTree, NrAssignments) {
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf());
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());

DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
/* The tree is
Choice(C)
0 Choice(B)
0 0 Leaf 1
0 1 Choice(A)
0 1 0 Leaf 1
0 1 1 Leaf 2
1 Choice(B)
1 0 Choice(A)
1 0 0 Leaf 3
1 0 1 Leaf 4
1 1 Leaf 5
*/

auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
CHECK(choice0);
EXPECT(choice0->branches()[0]->isLeaf());
auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
CHECK(choice00);
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());

auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
CHECK(choice1);
auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
CHECK(choice10);
auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
CHECK(choice11);
EXPECT(choice11->isLeaf());
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
}

/* ************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {
Expand Down

0 comments on commit 99c01c4

Please sign in to comment.