Skip to content

Commit

Permalink
Adding string constructor for enum
Browse files Browse the repository at this point in the history
  • Loading branch information
henryiii committed Sep 29, 2017
1 parent 64a99b9 commit 287aa03
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
14 changes: 13 additions & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ template <typename Type> class enum_ : public class_<Type> {

template <typename... Extra>
enum_(const handle &scope, const char *name, const Extra&... extra)
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope), m_name(name) {

constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;

Expand All @@ -1377,6 +1377,15 @@ template <typename Type> class enum_ : public class_<Type> {
return m;
}, return_value_policy::copy);
def(init([](Scalar i) { return static_cast<Type>(i); }));
def(init([this, m_entries_ptr](std::string value) -> Type {
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
std::string key = cast<str>(kv.first);
if(value == key || key == m_name + "::" + value) {
return cast<Type>(kv.second);
}
}
throw value_error("\"" + value + "\" is not a valid value for enum type " + m_name);
}));
def("__int__", [](Type value) { return (Scalar) value; });
#if PY_MAJOR_VERSION < 3
def("__long__", [](Type value) { return (Scalar) value; });
Expand Down Expand Up @@ -1436,8 +1445,10 @@ template <typename Type> class enum_ : public class_<Type> {
private:
dict m_entries;
handle m_parent;
std::string m_name;
};


NAMESPACE_BEGIN(detail)


Expand Down Expand Up @@ -1695,6 +1706,7 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
}
NAMESPACE_END(detail)


template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
void print(Args &&...args) {
auto c = detail::collect_arguments<policy>(std::forward<Args>(args)...);
Expand Down
16 changes: 16 additions & 0 deletions tests/test_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ TEST_SUBMODULE(enums, m) {
.value("ETwo", ETwo)
.export_values();

// test_conversion_enum
enum class ConversionEnum {
Convert1 = 1,
Convert2
};

py::enum_<ConversionEnum>(m, "ConversionEnum", py::arithmetic())
.value("Convert1", ConversionEnum::Convert1)
.value("Convert2", ConversionEnum::Convert2)
;
py::implicitly_convertible<py::str, ConversionEnum>();

m.def("test_conversion_enum", [](ConversionEnum z) {
return "ConversionEnum::" + std::string(z == ConversionEnum::Convert1 ? "Convert1" : "Convert2");
});

// test_scoped_enum
enum class ScopedEnum {
Two = 2,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_unscoped_enum():

assert int(m.UnscopedEnum.ETwo) == 2
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
assert str(m.UnscopedEnum("ETwo")) == "UnscopedEnum.ETwo"

# order
assert m.UnscopedEnum.EOne < m.UnscopedEnum.ETwo
Expand All @@ -40,9 +41,17 @@ def test_unscoped_enum():
assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne)
assert not (2 < m.UnscopedEnum.EOne)

def test_converstion_enum():
assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1"
assert m.test_conversion_enum(m.ConversionEnum("Convert1")) == "ConversionEnum::Convert1"
assert m.test_conversion_enum("Convert1") == "ConversionEnum::Convert1"
assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1"


def test_scoped_enum():
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
with pytest.raises(TypeError):
m.test_scoped_enum("Three")
z = m.ScopedEnum.Two
assert m.test_scoped_enum(z) == "ScopedEnum::Two"

Expand Down

0 comments on commit 287aa03

Please sign in to comment.