From cc7f4992b7884ab6380c98774049f76d86cad76d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 30 Mar 2022 19:42:09 -0400 Subject: [PATCH] record number of assignments for each leaf --- gtsam/discrete/DecisionTree-inl.h | 36 +++++++++++++------ gtsam/discrete/tests/testDecisionTree.cpp | 43 +++++++++++++++++++++++ 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index b6e5482978..826e54b955 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -59,15 +59,23 @@ namespace gtsam { /** constant stored in this leaf */ Y constant_; - /** Constructor from constant */ - Leaf(const Y& constant) : - constant_(constant) {} + /** The number of assignments contained within this leaf + * Particularly useful when leaves have been pruned. + */ + size_t nrAssignments_; + + /// Constructor from constant + Leaf(const Y& constant, size_t nrAssignments = 1) + : constant_(constant), nrAssignments_(nrAssignments) {} /** return the constant */ const Y& constant() const { return constant_; } + /// Return the number of assignments contained within this leaf. + size_t nrAssignments() const { return nrAssignments_; } + /// Leaf-Leaf equality bool sameLeaf(const Leaf& q) const override { return constant_ == q.constant_; @@ -108,14 +116,14 @@ namespace gtsam { /** apply unary operator */ NodePtr apply(const Unary& op) const override { - NodePtr f(new Leaf(op(constant_))); + NodePtr f(new Leaf(op(constant_), nrAssignments_)); return f; } /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, const Assignment& choices) const override { - NodePtr f(new Leaf(op(choices, constant_))); + NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); return f; } @@ -130,7 +138,8 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + // fL op gL + NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_)); return h; } @@ -141,7 +150,7 @@ namespace gtsam { /** choose a branch, create new memory ! */ NodePtr choose(const L& label, size_t index) const override { - return NodePtr(new Leaf(constant())); + return NodePtr(new Leaf(constant(), nrAssignments())); } bool isLeaf() const override { return true; } @@ -178,9 +187,16 @@ namespace gtsam { if (f->allSame_) { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; - assert(f0->isLeaf()); + + size_t nrAssignments = 0; + for(auto branch: f->branches()) { + assert(branch->isLeaf()); + nrAssignments += + boost::dynamic_pointer_cast(branch)->nrAssignments(); + } NodePtr newLeaf( - new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + new Leaf(boost::dynamic_pointer_cast(f0)->constant(), + nrAssignments)); return newLeaf; } else #endif @@ -640,7 +656,7 @@ namespace gtsam { // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = boost::dynamic_pointer_cast(f)) { - return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); } // Check if Choice diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index f234905e33..14cf307a58 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) { StringContainerTree converted(stringIntTree, container_of_int); } +/* ************************************************************************** */ +// Test nrAssignments. +TEST(DecisionTree, NrAssignments) { + pair A("A", 2), B("B", 2), C("C", 2); + DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); + EXPECT(tree.root_->isLeaf()); + auto leaf = boost::dynamic_pointer_cast(tree.root_); + EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); + + DT tree2({C, B, A}, "1 1 1 2 3 4 5 5"); + /* The tree is + Choice(C) + 0 Choice(B) + 0 0 Leaf 1 + 0 1 Choice(A) + 0 1 0 Leaf 1 + 0 1 1 Leaf 2 + 1 Choice(B) + 1 0 Choice(A) + 1 0 0 Leaf 3 + 1 0 1 Leaf 4 + 1 1 Leaf 5 + */ + + auto root = boost::dynamic_pointer_cast(tree2.root_); + CHECK(root); + auto choice0 = boost::dynamic_pointer_cast(root->branches()[0]); + CHECK(choice0); + EXPECT(choice0->branches()[0]->isLeaf()); + auto choice00 = boost::dynamic_pointer_cast(choice0->branches()[0]); + CHECK(choice00); + EXPECT_LONGS_EQUAL(2, choice00->nrAssignments()); + + auto choice1 = boost::dynamic_pointer_cast(root->branches()[1]); + CHECK(choice1); + auto choice10 = boost::dynamic_pointer_cast(choice1->branches()[0]); + CHECK(choice10); + auto choice11 = boost::dynamic_pointer_cast(choice1->branches()[1]); + CHECK(choice11); + EXPECT(choice11->isLeaf()); + EXPECT_LONGS_EQUAL(2, choice11->nrAssignments()); +} + /* ************************************************************************** */ // Test visit. TEST(DecisionTree, visit) {