Skip to content

Commit

Permalink
Merge pull request #19 from varunagrawal/feature/hybrid-wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Feb 16, 2022
2 parents f3a95a3 + 36f06fe commit 332f3f1
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 14 deletions.
15 changes: 8 additions & 7 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file GaussianMixture.h
* @brief Discrete-continuous conditional density
* @author Frank Dellaert
* @author Fan Jiang
* @date December 2021
*/

Expand All @@ -21,8 +22,8 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/DCGaussianMixtureFactor.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>

namespace gtsam {

Expand Down Expand Up @@ -63,21 +64,21 @@ class GaussianMixture
* GaussianConditionalMixture(const Conditionals& conditionals,
* const DiscreteKeys& discreteParentKeys)
*/
GaussianMixture(size_t nrFrontals,
const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Conditionals &conditionals);
GaussianMixture(size_t nrFrontals, const KeyVector& continuousKeys,
const DiscreteKeys& discreteKeys,
const Conditionals& conditionals);

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(const DiscreteValues& discreteVals) const;
GaussianConditional::shared_ptr operator()(
const DiscreteValues& discreteVals) const;

/// Returns the total number of continuous components
size_t nrComponents() {
size_t total = 0;
factors_.visit([&total](const GaussianFactor::shared_ptr &node) {
factors_.visit([&total](const GaussianFactor::shared_ptr& node) {
if (node) total += 1;
});
return total;
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using SharedFactor = boost::shared_ptr<Factor>;
* NonlinearFactorGraph, GaussianFactorGraph.
*/
template <typename FG>
class HybridFactorGraph : protected FactorGraph<Factor> {
class HybridFactorGraph : public FactorGraph<Factor> {
public:
using shared_ptr = boost::shared_ptr<HybridFactorGraph>;
using Base = FactorGraph<Factor>;
Expand Down
5 changes: 5 additions & 0 deletions gtsam/hybrid/NonlinearHybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ void NonlinearHybridFactorGraph::print(
factorGraph_.print("NonlinearFactorGraph", keyFormatter);
}

bool NonlinearHybridFactorGraph::equals(const NonlinearHybridFactorGraph& other,
double tol) const {
return Base::equals(other, tol);
}

GaussianHybridFactorGraph NonlinearHybridFactorGraph::linearize(
const Values& continuousValues) const {
// linearize the continuous factors
Expand Down
12 changes: 6 additions & 6 deletions gtsam/hybrid/NonlinearHybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ class GTSAM_EXPORT NonlinearHybridFactorGraph
const std::string& str = "NonlinearHybridFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

/**
* @return true if all internal graphs of `this` are equal to those of
* `other`
*/
bool equals(const NonlinearHybridFactorGraph& other, double tol = 1e-9) const;

/**
* Utility for retrieving the internal nonlinear factor graph
* @return the member variable nonlinearGraph_
Expand All @@ -122,12 +128,6 @@ class GTSAM_EXPORT NonlinearHybridFactorGraph
*/
GaussianHybridFactorGraph linearize(const Values& continuousValues) const;

/**
* @return true if all internal graphs of `this` are equal to those of
* `other`
*/
bool equals(const NonlinearHybridFactorGraph& other, double tol = 1e-9) const;

/// The total number of factors in the nonlinear factor graph.
size_t nrNonlinearFactors() const { return factorGraph_.size(); }

Expand Down
119 changes: 119 additions & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//*************************************************************************
// hybrid
//*************************************************************************

namespace gtsam {

// #include <gtsam/inference/Key.h>
// class gtsam::KeyVector;

#include <gtsam/hybrid/DCFactor.h>
#include <gtsam/slam/BetweenFactor.h>

virtual class DCFactor {};

#include <gtsam/hybrid/DCMixtureFactor.h>
template <T>
virtual class DCMixtureFactor : gtsam::DCFactor {
DCMixtureFactor();
DCMixtureFactor(const gtsam::KeyVector& keys,
const gtsam::DiscreteKeys& discreteKeys,
const std::vector<T*>& factors, bool normalized = false);
};

typedef gtsam::DCMixtureFactor<gtsam::BetweenFactor<double>>
DCMixtureFactorBetweenFactorDouble;

#include <gtsam/hybrid/DCGaussianMixtureFactor.h>

virtual class DCGaussianMixtureFactor : gtsam::DCFactor {
DCGaussianMixtureFactor();
DCGaussianMixtureFactor(
const gtsam::KeyVector& continuousKeys,
const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DCGaussianMixtureFactor::Factors& factors);
};

#include <gtsam/hybrid/GaussianMixture.h>

virtual class GaussianMixture : gtsam::DCGaussianMixtureFactor {
GaussianMixture();
GaussianMixture(size_t nrFrontals, const gtsam::KeyVector& continuousKeys,
const gtsam::DiscreteKeys& discreteKeys,
const gtsam::GaussianMixture::Conditionals& conditionals);
};

#include <gtsam/hybrid/DCFactorGraph.h>

class DCFactorGraph {
DCFactorGraph();
gtsam::DiscreteKeys discreteKeys() const;
};

#include <gtsam/hybrid/HybridFactorGraph.h>

template <FG>
virtual class HybridFactorGraph {
HybridFactorGraph();
HybridFactorGraph(const FG& factorGraph,
const gtsam::DiscreteFactorGraph& discreteGraph,
const gtsam::DCFactorGraph& dcGraph);

bool equals(const gtsam::HybridFactorGraph<FG>& other,
double tol = 1e-9) const;
void print(const std::string& str = "HybridFactorGraph",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

const gtsam::DiscreteFactorGraph& discreteGraph() const;
const gtsam::DCFactorGraph& dcGraph() const;

gtsam::DiscreteKeys discreteKeys() const;
};

typedef gtsam::HybridFactorGraph<gtsam::NonlinearFactorGraph>
HybridFactorGraphNonlinear;

#include <gtsam/hybrid/NonlinearHybridFactorGraph.h>

virtual class NonlinearHybridFactorGraph {
NonlinearHybridFactorGraph();
NonlinearHybridFactorGraph(const gtsam::NonlinearFactorGraph& nonlinearGraph,
const gtsam::DiscreteFactorGraph& discreteGraph,
const gtsam::DCFactorGraph& dcGraph);

const gtsam::NonlinearFactorGraph& nonlinearGraph() const;

gtsam::GaussianHybridFactorGraph linearize(
const gtsam::Values& continuousValues) const;

size_t size() const;
bool equals(const gtsam::NonlinearHybridFactorGraph& other,
double tol = 1e-9) const;
void print(const std::string& str = "NonlinearHybridFactorGraph",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/GaussianHybridFactorGraph.h>

class GaussianHybridFactorGraph {
GaussianHybridFactorGraph();
GaussianHybridFactorGraph(const gtsam::GaussianFactorGraph& gaussianGraph,
const gtsam::DiscreteFactorGraph& discreteGraph,
const gtsam::DCFactorGraph& dcGraph);

size_t size() const;
void print(const std::string& str = "GaussianHybridFactorGraph",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/IncrementalHybrid.h>

class IncrementalHybrid {
void update(gtsam::GaussianHybridFactorGraph graph,
const gtsam::Ordering& ordering);
};

} // namespace gtsam
1 change: 1 addition & 0 deletions matlab/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ set(interface_files
${GTSAM_SOURCE_DIR}/gtsam/linear/linear.i
${GTSAM_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i
${GTSAM_SOURCE_DIR}/gtsam/symbolic/symbolic.i
${GTSAM_SOURCE_DIR}/gtsam/hybrid/hybrid.i
${GTSAM_SOURCE_DIR}/gtsam/sam/sam.i
${GTSAM_SOURCE_DIR}/gtsam/slam/slam.i
${GTSAM_SOURCE_DIR}/gtsam/sfm/sfm.i
Expand Down
1 change: 1 addition & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i
${PROJECT_SOURCE_DIR}/gtsam/symbolic/symbolic.i
${PROJECT_SOURCE_DIR}/gtsam/hybrid/hybrid.i
${PROJECT_SOURCE_DIR}/gtsam/sam/sam.i
${PROJECT_SOURCE_DIR}/gtsam/slam/slam.i
${PROJECT_SOURCE_DIR}/gtsam/sfm/sfm.i
Expand Down
12 changes: 12 additions & 0 deletions python/gtsam/preamble/hybrid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
* automatic STL binding, such that the raw objects can be accessed in Python.
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
12 changes: 12 additions & 0 deletions python/gtsam/specializations/hybrid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `py::bind_vector` and similar machinery gives the std container a Python-like
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
* and saves one copy operation.
*/
13 changes: 13 additions & 0 deletions python/gtsam/tests/test_Hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import gtsam
import numpy as np
from gtsam import GaussianHybridFactorGraph
from gtsam.utils.test_case import GtsamTestCase


class TestHybridElimination(GtsamTestCase):
def setUp(self) -> None:
self.ghfg = GaussianHybridFactorGraph()

def test_elimination(self):
# Check if constructed correctly
self.assertIsInstance(self.ghfg, GaussianHybridFactorGraph)

0 comments on commit 332f3f1

Please sign in to comment.