-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
cpp_template_pybind_test.cc
137 lines (110 loc) · 4.14 KB
/
cpp_template_pybind_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include "drake/bindings/pydrake/common/cpp_template_pybind.h"
// @file
// Tests the public interfaces in `cpp_template.py` and
// `cpp_template_pybind.h`.
#include <string>
#include <utility>
#include <vector>
#include <gtest/gtest.h>
#include "pybind11/embed.h"
#include "pybind11/eval.h"
#include "pybind11/pybind11.h"
#include "drake/bindings/pydrake/test/test_util_pybind.h"
#include "drake/common/nice_type_name.h"
#include "drake/common/test_utilities/expect_throws_message.h"
using std::string;
using std::vector;
namespace drake {
namespace pydrake {
namespace {
using test::SynchronizeGlobalsForPython3;
// Minimal class template used as a binding fixture: its single method reports
// the (nice) names of its template arguments, which lets the Python-side
// checks tell which C++ instantiation was actually selected.
template <typename... Ts>
struct SimpleTemplate {
  vector<string> GetNames() {
    vector<string> names{NiceTypeName::Get<Ts>()...};
    return names;
  }
};
// Binds `SimpleTemplate<Ts...>` into module `m` under a temporary class name,
// registers it with `AddTemplateClass` as an instantiation of the Python
// template named "SimpleTemplate" (keyed by `GetPyParam<Ts...>()`), and
// returns the bound class object.
template <typename... Ts>
py::object BindSimpleTemplate(py::module m) {
  using Class = SimpleTemplate<Ts...>;
  const auto class_name = TemporaryClassName<Class>();
  py::class_<Class> bound_class(m, class_name.c_str());
  bound_class.def(py::init<>());
  bound_class.def("GetNames", &Class::GetNames);
  AddTemplateClass(m, "SimpleTemplate", bound_class, GetPyParam<Ts...>());
  // Move explicitly: the type of `bound_class` (`py::class_<Class>`) differs
  // from the declared return type (`py::object`).
  return std::move(bound_class);
}
// Evaluates the Python expression `expr` with `py::eval`, casts the result to
// `T`, and expects it to equal `expected`.
//
// Every expression check in this file goes through this one helper. Without
// the streamed message, a failure would only report this line and the generic
// text `py::eval(expr).cast<T>()`. Appending `expr` identifies which Python
// expression actually failed.
template <typename T>
void CheckValue(const string& expr, const T& expected) {
  EXPECT_EQ(py::eval(expr).cast<T>(), expected)
      << "Python expression: " << expr;
}
GTEST_TEST(CppTemplateTest, TemplateClass) {
  py::module main_module("__main__");
  // Bind two instantiations under the single Python template `SimpleTemplate`;
  // the first one is additionally exposed directly as `DefaultInst`.
  main_module.attr("DefaultInst") = BindSimpleTemplate<int>(main_module);
  BindSimpleTemplate<int, double>(main_module);
  const vector<string> names_int = {"int"};
  const vector<string> names_int_double = {"int", "double"};
  SynchronizeGlobalsForPython3(main_module);
  // Python `float` selects the C++ `double` instantiation.
  CheckValue("DefaultInst().GetNames()", names_int);
  CheckValue("SimpleTemplate[int]().GetNames()", names_int);
  CheckValue("SimpleTemplate[int, float]().GetNames()", names_int_double);

  // Register a function whose only accepted argument is a SimpleTemplate<int>.
  main_module.def("simple_func", [](const SimpleTemplate<int>&) {});
  SynchronizeGlobalsForPython3(main_module);
  // Calling it with a wrong argument must produce an error message that names
  // the template instantiation. N.B. The pattern uses `[^\0]` because C++
  // regex has no equivalent of Python re's DOTALL flag; `[\s\S]` *should*
  // work, but Apple LLVM 10.0.0 does not handle it.
  DRAKE_EXPECT_THROWS_MESSAGE(py::eval("simple_func('incorrect value')"),
      std::runtime_error,
      R"([^\0]*incompatible function arguments[^\0]*\(arg0: __main__\.SimpleTemplate\[int\]\)[^\0]*)");  // NOLINT
}
// Free function template used as a binding fixture: returns the (nice) names
// of its template arguments so Python-side checks can see which
// instantiation was called.
template <typename... Ts>
vector<string> SimpleFunction() {
  vector<string> names{NiceTypeName::Get<Ts>()...};
  return names;
}
GTEST_TEST(CppTemplateTest, TemplateFunction) {
  py::module main_module("__main__");
  // Register both instantiations under the one Python name `SimpleFunction`;
  // Python code picks one by indexing with the template parameters.
  AddTemplateFunction(main_module, "SimpleFunction", &SimpleFunction<int>,
      GetPyParam<int>());
  AddTemplateFunction(main_module, "SimpleFunction",
      &SimpleFunction<int, double>, GetPyParam<int, double>());
  const vector<string> names_int = {"int"};
  const vector<string> names_int_double = {"int", "double"};
  SynchronizeGlobalsForPython3(main_module);
  // Python `float` selects the C++ `double` instantiation.
  CheckValue("SimpleFunction[int]()", names_int);
  CheckValue("SimpleFunction[int, float]()", names_int_double);
}
// Plain (non-template) class with a member function template, used to
// exercise `AddTemplateMethod`. The method returns the (nice) names of its
// template arguments.
struct SimpleType {
  template <typename... Ts>
  vector<string> SimpleMethod() {
    vector<string> names{NiceTypeName::Get<Ts>()...};
    return names;
  }
};
GTEST_TEST(CppTemplateTest, TemplateMethod) {
  py::module main_module("__main__");
  py::class_<SimpleType> py_simple_type(main_module, "SimpleType");
  py_simple_type.def(py::init<>());
  // Attach both instantiations of the member template to the class under the
  // single method name `SimpleMethod`.
  AddTemplateMethod(py_simple_type, "SimpleMethod",
      &SimpleType::SimpleMethod<int>, GetPyParam<int>());
  AddTemplateMethod(py_simple_type, "SimpleMethod",
      &SimpleType::SimpleMethod<int, double>, GetPyParam<int, double>());
  const vector<string> names_int = {"int"};
  const vector<string> names_int_double = {"int", "double"};
  SynchronizeGlobalsForPython3(main_module);
  // Python `float` selects the C++ `double` instantiation.
  CheckValue("SimpleType().SimpleMethod[int]()", names_int);
  CheckValue("SimpleType().SimpleMethod[int, float]()", names_int_double);
}
// Test entry point (invoked from the global `main`). A single embedded Python
// interpreter is created here and kept alive for the whole gtest run.
int main(int argc, char** argv) {
  // The interpreter is constructed once, not per test (e.g. via `SetUp()`),
  // because reconstructing `scoped_interpreter` repeatedly while *also*
  // importing `numpy` wreaks havoc.
  py::scoped_interpreter interpreter_guard;
  ::testing::InitGoogleTest(&argc, argv);
  const int test_result = RUN_ALL_TESTS();
  return test_result;
}
} // namespace
} // namespace pydrake
} // namespace drake
int main(int argc, char** argv) {
return drake::pydrake::main(argc, argv);
}