From 6374488519683d324cb5bf524971c7764d8bd89e Mon Sep 17 00:00:00 2001 From: Henry Fredrick Schreiner Date: Fri, 29 Sep 2017 09:16:23 -0400 Subject: [PATCH] Adding string constructor for enum --- docs/classes.rst | 15 +++++++++++++++ include/pybind11/pybind11.h | 8 ++++++++ tests/test_enum.cpp | 16 ++++++++++++++++ tests/test_enum.py | 21 +++++++++++++++++++++ 4 files changed, 60 insertions(+) diff --git a/docs/classes.rst b/docs/classes.rst index 75a8fb2c87..3050fb84fc 100644 --- a/docs/classes.rst +++ b/docs/classes.rst @@ -506,6 +506,21 @@ The ``name`` property returns the name of the enum value as a unicode string. >>> pet_type.name 'Cat' +You can also access the enumeration using a string using the enum's constructor, +such as ``Pet('Cat')``. This makes it possible to automatically convert a string +to an enumeration in an API if the enumeration is marked implicitly convertible +from a string, with a line such as: + +.. code-block:: cpp + + py::implicitly_convertible(); + +Now, in Python, the following code will also correctly construct a cat: + +.. code-block:: pycon + + >>> p = Pet('Lucy', 'Cat') + .. note:: When the special tag ``py::arithmetic()`` is specified to the ``enum_`` diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 6947d440f7..ef92acd0e4 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1411,6 +1411,14 @@ template class enum_ : public class_ { return m; }, return_value_policy::copy); def(init([](Scalar i) { return static_cast(i); })); + def(init([name, m_entries_ptr](std::string value) -> Type { + pybind11::dict values = reinterpret_borrow(m_entries_ptr); + pybind11::str key = pybind11::str(value); + if (values.contains(key)) + return pybind11::cast(values[key]); + else + throw value_error("\"" + value + "\" is not a valid value for enum type " + name); + })); def("__int__", [](Type value) { return (Scalar) value; }); #if PY_MAJOR_VERSION < 3 def("__long__", [](Type value) { return (Scalar) value; }); diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 4cd14a96a6..d2c54b53a0 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -20,6 +20,22 @@ TEST_SUBMODULE(enums, m) { .value("ETwo", ETwo, "Docstring for ETwo") .export_values(); + // test_conversion_enum + enum class ConversionEnum { + Convert1 = 1, + Convert2 + }; + + py::enum_(m, "ConversionEnum", py::arithmetic()) + .value("Convert1", ConversionEnum::Convert1) + .value("Convert2", ConversionEnum::Convert2) + ; + py::implicitly_convertible(); + + m.def("test_conversion_enum", [](ConversionEnum z) { + return "ConversionEnum::" + std::string(z == ConversionEnum::Convert1 ? "Convert1" : "Convert2"); + }); + // test_scoped_enum enum class ScopedEnum { Two = 2, diff --git a/tests/test_enum.py b/tests/test_enum.py index c2c272a25c..dfad036baa 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -54,6 +54,7 @@ def test_unscoped_enum(): assert int(m.UnscopedEnum.ETwo) == 2 assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo" + assert str(m.UnscopedEnum("ETwo")) == "UnscopedEnum.ETwo" # order assert m.UnscopedEnum.EOne < m.UnscopedEnum.ETwo @@ -70,8 +71,28 @@ def test_unscoped_enum(): assert not (2 < m.UnscopedEnum.EOne) +def test_converstion_enum(): + assert m.test_conversion_enum(m.ConversionEnum.Convert1) == "ConversionEnum::Convert1" + assert m.test_conversion_enum(m.ConversionEnum("Convert1")) == "ConversionEnum::Convert1" + assert m.test_conversion_enum("Convert1") == "ConversionEnum::Convert1" + + +def test_conversion_enum_raises(): + with pytest.raises(ValueError) as excinfo: + m.ConversionEnum("Convert0") + assert str(excinfo.value) == "\"Convert0\" is not a valid value for enum type ConversionEnum" + + +def test_conversion_enum_raises_implicit(): + with pytest.raises(ValueError) as excinfo: + m.test_conversion_enum("Convert0") + assert str(excinfo.value) == "\"Convert0\" is not a valid value for enum type ConversionEnum" + + def test_scoped_enum(): assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three" + with pytest.raises(TypeError): + m.test_scoped_enum("Three") z = m.ScopedEnum.Two assert m.test_scoped_enum(z) == "ScopedEnum::Two"