From 80527832d63b2e742534e17455e1c803e0eb2455 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Sun, 29 Oct 2017 16:32:26 -0300 Subject: [PATCH] Detect and fail if using mismatched holders This adds a check when registering a class or a function with a holder return that the same wrapped type hasn't been previously seen using a different holder type. This fixes #1138 by detecting the failure; currently attempting to use two different holder types (e.g. a unique_ptr and shared_ptr) in difference places can segfault because we don't have any type safety on the holder instances. --- include/pybind11/cast.h | 27 +++++++++++++++++++++++++++ include/pybind11/detail/internals.h | 3 ++- include/pybind11/pybind11.h | 6 ++++++ tests/test_smart_ptr.cpp | 17 ++++++++++++++++- tests/test_smart_ptr.py | 11 +++++++++++ 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 715ec932d9..0789a5c8ec 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1493,6 +1493,33 @@ template struct is_holder_type : template struct is_holder_type> : std::true_type {}; +template using is_holder = any_of< + is_template_base_of>, + is_template_base_of>>; + +template +void check_for_holder_mismatch(enable_if_t::value, int> = 0) {} +template +void check_for_holder_mismatch(enable_if_t::value, int> = 0) { + using iholder = intrinsic_t; + using base_type = decltype(*holder_helper::get(std::declval())); + auto &holder_typeinfo = typeid(iholder); + auto ins = get_internals().holders_seen.emplace(typeid(base_type), &holder_typeinfo); + + auto debug = type_id(); + if (!ins.second && !same_type(*ins.first->second, holder_typeinfo)) { +#ifdef NDEBUG + pybind11_fail("Mismatched holders detected (compile in debug mode for details)"); +#else + std::string seen_holder_name(ins.first->second->name()); + detail::clean_type_id(seen_holder_name); + pybind11_fail("Mismatched holders detected: " + " attempting to use holder type " + type_id() + ", but " + type_id() + + " was already seen using holder type " + seen_holder_name); +#endif + } +} + template struct handle_type_name { static constexpr auto name = _(); }; template <> struct handle_type_name { static constexpr auto name = _(PYBIND11_BYTES_NAME); }; template <> struct handle_type_name { static constexpr auto name = _("*args"); }; diff --git a/include/pybind11/detail/internals.h b/include/pybind11/detail/internals.h index 213cbaeb21..4d7ef6d8ee 100644 --- a/include/pybind11/detail/internals.h +++ b/include/pybind11/detail/internals.h @@ -68,6 +68,7 @@ struct internals { type_map registered_types_cpp; // std::type_index -> pybind11's type information std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) std::unordered_multimap registered_instances; // void * -> instance* + type_map holders_seen; // type -> seen holder type (to detect holder conflicts) std::unordered_set, overload_hash> inactive_overload_cache; type_map> direct_conversions; std::unordered_map> patients; @@ -111,7 +112,7 @@ struct type_info { }; /// Tracks the `internals` and `type_info` ABI version independent of the main library version -#define PYBIND11_INTERNALS_VERSION 1 +#define PYBIND11_INTERNALS_VERSION 2 #if defined(WITH_THREAD) # define PYBIND11_INTERNALS_KIND "" diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 90b8ebcc63..13ac9ded68 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -127,6 +127,9 @@ class cpp_function : public function { static_assert(detail::expected_num_args(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs), "The number of argument annotations does not match the number of function arguments"); + // Fail if we've previously seen a different holder around the held type + detail::check_for_holder_mismatch(); + /* Dispatch code which converts function arguments and performs the actual function call */ rec->impl = [](detail::function_call &call) -> handle { cast_in args_converter; @@ -1045,6 +1048,9 @@ class class_ : public detail::generic_type { none_of...>::value), // no multiple_inheritance attr "Error: multiple inheritance bases must be specified via class_ template options"); + // Fail if we've previously seen a different holder around the type + detail::check_for_holder_mismatch(); + type_record record; record.scope = scope; record.name = name; diff --git a/tests/test_smart_ptr.cpp b/tests/test_smart_ptr.cpp index dccb1e9be5..6b8a4c17c4 100644 --- a/tests/test_smart_ptr.cpp +++ b/tests/test_smart_ptr.cpp @@ -40,7 +40,7 @@ template class huge_unique_ptr { uint64_t padding[10]; public: huge_unique_ptr(T *p) : ptr(p) {}; - T *get() { return ptr.get(); } + T *get() const { return ptr.get(); } }; PYBIND11_DECLARE_HOLDER_TYPE(T, huge_unique_ptr); @@ -267,4 +267,19 @@ TEST_SUBMODULE(smart_ptr, m) { list.append(py::cast(e)); return list; }); + + // test_holder_mismatch + // Tests the detection of trying to use mismatched holder types around the same instance type + struct HeldByShared {}; + struct HeldByUnique {}; + py::class_>(m, "HeldByShared"); + m.def("register_mismatch_return", [](py::module m) { + // Fails: the class was already registered with a shared_ptr holder + m.def("bad1", []() { return std::unique_ptr(new HeldByShared()); }); + }); + m.def("return_shared", []() { return std::make_shared(); }); + m.def("register_mismatch_class", [](py::module m) { + // Fails: `return_shared2' already returned this via shared_ptr holder + py::class_(m, "HeldByUnique"); + }); } diff --git a/tests/test_smart_ptr.py b/tests/test_smart_ptr.py index 4dfe0036fc..2452e2f811 100644 --- a/tests/test_smart_ptr.py +++ b/tests/test_smart_ptr.py @@ -218,3 +218,14 @@ def test_shared_ptr_gc(): pytest.gc_collect() for i, v in enumerate(el.get()): assert i == v.value() + + +def test_holder_mismatch(): + """#1138: segfault if mixing holder types""" + with pytest.raises(RuntimeError) as excinfo: + m.register_mismatch_return(m) + assert "Mismatched holders detected" in str(excinfo) + + with pytest.raises(RuntimeError) as excinfo: + m.register_mismatch_class(m) + assert "Mismatched holders detected" in str(excinfo)