Skip to content

Commit

Permalink
Merge pull request #33 from oberbichler/feature/f-d-dd
Browse files Browse the repository at this point in the history
Add hj::f/d/dd
  • Loading branch information
oberbichler authored Mar 24, 2021
2 parents ea2dc5e + 45bbd42 commit 24b648c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/python_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/eval.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/stl_bind.h>
Expand Down Expand Up @@ -48,10 +49,6 @@ void register_ddscalar(pybind11::module& m, const std::string& name)
.def("__repr__", &Type::to_string)
.def("abs", &Type::abs)
.def("eval", &Type::eval, "d"_a)
.def("h", py::overload_cast<hj::index, hj::index>(&Type::h), "row"_a, "col"_a)
.def("set_h", py::overload_cast<hj::index, hj::index, TScalar>(&Type::set_h), "row"_a, "col"_a, "value"_a)
.def("hm", py::overload_cast<std::string>(&Type::hm, py::const_), "mode"_a="full")
.def("set_hm", &Type::set_hm, "value"_a)
// methods: arithmetic operations
.def("reciprocal", &Type::reciprocal)
.def("sqrt", &Type::sqrt)
Expand Down Expand Up @@ -146,7 +143,13 @@ void register_ddscalar(pybind11::module& m, const std::string& name)
// FIXME: add from_gradient
} else {
py_class
.def(py::init(&Type::from_arrays), "f"_a, "g"_a, "hm"_a);
// constructors
.def(py::init(&Type::from_arrays), "f"_a, "g"_a, "hm"_a)
// methods
.def("h", py::overload_cast<hj::index, hj::index>(&Type::h), "row"_a, "col"_a)
.def("set_h", py::overload_cast<hj::index, hj::index, TScalar>(&Type::set_h), "row"_a, "col"_a, "value"_a)
.def("hm", py::overload_cast<std::string>(&Type::hm, py::const_), "mode"_a="full")
.def("set_hm", &Type::set_hm, "value"_a);
}

if constexpr(Type::is_dynamic()) {
Expand All @@ -168,6 +171,8 @@ void register_ddscalar(pybind11::module& m, const std::string& name)

PYBIND11_MODULE(hyperjet, m)
{
namespace py = pybind11;

m.doc() = "HyperJet by Thomas Oberbichler";
m.attr("__author__") = "Thomas Oberbichler";
m.attr("__copyright__") = "Copyright (c) 2019-2021, Thomas Oberbichler";
Expand Down Expand Up @@ -210,4 +215,15 @@ PYBIND11_MODULE(hyperjet, m)
register_ddscalar<2, double, 14>(m, "DD14Scalar");
register_ddscalar<2, double, 15>(m, "DD15Scalar");
register_ddscalar<2, double, 16>(m, "DD16Scalar");

// utilities
{
py::object numpy = py::module::import("numpy");
auto global = py::dict();
global["np"] = numpy;

m.attr("f") = py::eval("np.vectorize(lambda v: v.f if hasattr(v, 'f') else v)", global);
m.attr("d") = py::eval("np.vectorize(lambda v: v.g if hasattr(v, 'g') else np.zeros((0)), signature='()->(n)')", global);
m.attr("dd") = py::eval("np.vectorize(lambda v: v.hm() if hasattr(v, 'hm') else np.zeros((0, 0)), signature='()->(n,m)')", global);
}
}
93 changes: 93 additions & 0 deletions tests/test_DDScalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,96 @@ def test_log2(ctx):
def test_log10(ctx):
r = np.log10(ctx.u1)
ctx.check(r, [0.255272505103306, 0.289529654602168, -0.0868588963806504, -0.0965098848673893, 0, 0.0173717792761301])


@pytest.mark.parametrize('ctx', **test_data)
def test_f(ctx):
u = [ctx.u1, ctx.u2]

f = hj.f(u)

assert_equal(f[0], u[0].f)
assert_equal(f[1], u[1].f)

v = np.dot(u, u)

f = hj.f(v)

assert_equal(f, v.f)


@pytest.mark.parametrize('ctx', **test_data)
def test_d(ctx):
u = [ctx.u1, ctx.u2]

d = hj.d(u)

assert_equal(d[0], u[0].g)
assert_equal(d[1], u[1].g)

v = np.dot(u, u)

d = hj.d(v)

assert_equal(d, v.g)


@pytest.mark.parametrize('ctx', **test_data)
def test_dd(ctx):
if ctx.dtype.order < 2:
return

u = [ctx.u1, ctx.u2]

dd = hj.dd(u)

assert_equal(dd[0], u[0].hm())
assert_equal(dd[1], u[1].hm())

v = np.dot(u, u)

dd = hj.dd(v)

assert_equal(dd, v.hm())


def test_f_of_scalar():
assert_equal(hj.f(1), 1)

assert_equal(hj.f([1, 2, 3, 4]), [1, 2, 3, 4])

assert_equal(hj.f([[1, 2, 3, 4]]), [[1, 2, 3, 4]])

assert_equal(hj.f([[1, 2], [3, 4]]), [[1, 2], [3, 4]])

assert_equal(hj.f(np.array([1, 2, 3, 4])), [1, 2, 3, 4])

assert_equal(hj.f(np.array([[1, 2], [3, 4]])), [[1, 2], [3, 4]])


def test_d_of_scalar():
assert_equal(hj.d(1), np.empty(0))

assert_equal(hj.d([1, 2, 3, 4]), np.empty((4, 0)))

assert_equal(hj.d([[1, 2, 3, 4]]), np.empty((1, 4, 0)))

assert_equal(hj.d([[1, 2], [3, 4]]), np.empty((2, 2, 0)))

assert_equal(hj.d(np.array([1, 2, 3, 4])), np.empty((4, 0)))

assert_equal(hj.d(np.array([[1, 2], [3, 4]])), np.empty((2, 2, 0)))


def test_dd_of_scalar():
assert_equal(hj.dd(1), 0)

assert_equal(hj.dd([1, 2, 3, 4]), np.empty((4, 0, 0)))

assert_equal(hj.dd([[1, 2, 3, 4]]), np.empty((1, 4, 0, 0)))

assert_equal(hj.dd([[1, 2], [3, 4]]), np.empty((2, 2, 0, 0)))

assert_equal(hj.dd(np.array([1, 2, 3, 4])), np.empty((4, 0, 0)))

assert_equal(hj.dd(np.array([[1, 2], [3, 4]])), np.empty((2, 2, 0, 0)))

0 comments on commit 24b648c

Please sign in to comment.