Skip to content

Commit

Permalink
Lock function arguments at compile time (#720)
Browse files Browse the repository at this point in the history
This commit refactors the argument locking so that it occurs at
compile-time without imposing runtime overheads. The change applies to
free-threaded extensions.

Behavior differences compared to the prior approach:

- it is no longer possible to do ``nb::arg().lock(false)`` or
  ``.lock(runtime_determined_value)``

- we no longer prohibit locking self in ``__init__``; changing this
  would also require restoring ``cast_flags::lock``, and it's not clear
  that the benefit outweighs the complexity.
  • Loading branch information
oremanj authored and wjakob committed Sep 20, 2024
1 parent 47d04ad commit 6d60ed7
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 116 deletions.
12 changes: 7 additions & 5 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,9 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,

.. cpp:function:: template <typename T> arg_v operator=(T &&value) const

Assign a default value to the argument.
Return an argument annotation that is like this one but also assigns a
default value to the argument. The default will be converted into a Python
object immediately, so its bindings must have already been defined.

.. cpp:function:: arg &none(bool value = true)

Expand All @@ -1642,11 +1644,11 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,
explain it in docstrings and stubs (``str(value)``) does not produce
acceptable output.

.. cpp:function:: arg &lock(bool value = true)
.. cpp:function:: arg_locked lock()

Set a flag noting that this argument must be locked when dispatching a
function call in free-threaded Python extensions. It does nothing in
regular GIL-protected extensions.
Return an argument annotation that is like this one but also requests that
this argument be locked when dispatching a function call in free-threaded
Python extensions. It does nothing in regular GIL-protected extensions.

.. cpp:struct:: is_method

Expand Down
102 changes: 83 additions & 19 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ struct name {
};

struct arg_v;
struct arg_locked;
struct arg_locked_v;

// Basic function argument descriptor (no default value, not locked)
struct arg {
NB_INLINE constexpr explicit arg(const char *name = nullptr) : name_(name), signature_(nullptr) { }

// operator= can be used to provide a default value
template <typename T> NB_INLINE arg_v operator=(T &&value) const;

// Mutators that don't change default value or locked state
NB_INLINE arg &noconvert(bool value = true) {
convert_ = !value;
return *this;
Expand All @@ -31,29 +39,75 @@ struct arg {
none_ = value;
return *this;
}

NB_INLINE arg &lock(bool value = true) {
lock_ = value;
return *this;
}

NB_INLINE arg &sig(const char *value) {
signature_ = value;
return *this;
}

// After lock(), this argument is locked
NB_INLINE arg_locked lock();

const char *name_, *signature_;
uint8_t convert_{ true };
bool none_{ false };
bool lock_{ false };
};

// Argument annotation that combines the settings of an (unlocked) ``arg``
// with a default value stored as a Python object
struct arg_v : arg {
    object value;

    NB_INLINE arg_v(const arg &base, object &&value)
        : arg(base), value(std::move(value)) { }

private:
    // The inherited setters return a reference to the ``arg`` base, so
    // chaining through them would lose the default value. They are made
    // inaccessible here.
    using arg::lock;
    using arg::sig;
    using arg::none;
    using arg::noconvert;
};

// Argument annotation for a locked argument without a default value. It
// stores the same name/conversion/None/signature settings as ``arg``.
struct arg_locked : arg {
    NB_INLINE constexpr explicit arg_locked(const char *name = nullptr)
        : arg(name) { }
    NB_INLINE constexpr explicit arg_locked(const arg &base)
        : arg(base) { }

    // Attaches a default value, yielding an ``arg_locked_v``
    template <typename T> NB_INLINE arg_locked_v operator=(T &&value) const;

    // The setters below shadow those of ``arg`` so that a chained call keeps
    // the ``arg_locked`` type instead of returning a plain ``arg &``.
    NB_INLINE arg_locked &sig(const char *value) {
        signature_ = value;
        return *this;
    }
    NB_INLINE arg_locked &none(bool value = true) {
        none_ = value;
        return *this;
    }
    NB_INLINE arg_locked &noconvert(bool value = true) {
        convert_ = !value;
        return *this;
    }

    // Calling lock() a second time is permitted and changes nothing
    NB_INLINE arg_locked &lock() { return *this; }
};

// Argument annotation for a locked argument that also carries a default
// value, stored as a Python object
struct arg_locked_v : arg_locked {
    object value;

    NB_INLINE arg_locked_v(const arg_locked &base, object &&value)
        : arg_locked(base), value(std::move(value)) { }

private:
    // The inherited setters return a reference to the ``arg_locked`` base, so
    // chaining through them would lose the default value. They are made
    // inaccessible here.
    using arg_locked::lock;
    using arg_locked::sig;
    using arg_locked::none;
    using arg_locked::noconvert;
};

// Produces a locked copy of this annotation. Defined out of line because
// ``arg_locked`` must be a complete type at this point.
NB_INLINE arg_locked arg::lock() {
    return arg_locked(*this);
}

template <typename... Ts> struct call_guard {
using type = detail::tuple<Ts...>;
};
Expand Down Expand Up @@ -133,9 +187,7 @@ enum class func_flags : uint32_t {
/// Does this overload specify a custom function signature (for docstrings, typing)
has_signature = (1 << 16),
/// Does this function have one or more nb::keep_alive() annotations?
has_keep_alive = (1 << 17),
/// Free-threaded Python: does the binding lock the 'self' argument
lock_self = (1 << 18)
has_keep_alive = (1 << 17)
};

enum cast_flags : uint8_t {
Expand All @@ -148,14 +200,11 @@ enum cast_flags : uint8_t {
// Indicates that the function dispatcher should accept 'None' arguments
accepts_none = (1 << 2),

// Indicates that a function argument must be locked before dispatching a call
lock = (1 << 3),

// Indicates that this cast is performed by nb::cast or nb::try_cast.
// This implies that objects added to the cleanup list may be
// released immediately after the caster's final output value is
// obtained, i.e., before it is used.
manual = (1 << 4)
manual = (1 << 3)
};


Expand Down Expand Up @@ -300,8 +349,6 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) {
flag |= (uint8_t) cast_flags::accepts_none;
if (a.convert_)
flag |= (uint8_t) cast_flags::convert;
if (a.lock_)
flag |= (uint8_t) cast_flags::lock;

arg_data &arg = f.args[index];
arg.flag = flag;
Expand All @@ -310,21 +357,27 @@ NB_INLINE void func_extra_apply(F &f, const arg &a, size_t &index) {
arg.value = nullptr;
index++;
}
// arg_locked will select the arg overload; the locking is added statically
// in nb_func.h

// Argument with a default value: fill in the slot through the plain ``arg``
// overload, then record the default on that same slot.
template <typename F>
NB_INLINE void func_extra_apply(F &f, const arg_v &a, size_t &index) {
    // Take the slot reference first; the delegated call advances 'index'
    arg_data &slot = f.args[index];
    const arg &base = a;
    func_extra_apply(f, base, index);
    slot.value = a.value.ptr();
}
// Locked argument with a default value: fill in the slot by delegating with
// the ``arg_locked`` base, then record the default on that same slot.
template <typename F>
NB_INLINE void func_extra_apply(F &f, const arg_locked_v &a, size_t &index) {
    // Take the slot reference first; the delegated call advances 'index'
    arg_data &slot = f.args[index];
    const arg_locked &base = a;
    func_extra_apply(f, base, index);
    slot.value = a.value.ptr();
}

// nb::kw_only annotation: this overload leaves the function record and the
// argument index untouched
template <typename F>
NB_INLINE void func_extra_apply(F &, kw_only, size_t &) {}

template <typename F>
NB_INLINE void func_extra_apply(F &f, lock_self, size_t &) {
f.flags |= (uint32_t) func_flags::lock_self;
}
NB_INLINE void func_extra_apply(F &, lock_self, size_t &) {}

// nb::call_guard<...> annotation: this overload leaves the function record
// and the argument index untouched
template <typename F, typename... Ts>
NB_INLINE void func_extra_apply(F &, call_guard<Ts...>, size_t &) {}
Expand All @@ -337,6 +390,7 @@ NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive<Nurse, Patient>, size
// Compile-time summary of a list of function annotations. This primary
// template supplies the defaults: no call guard type, no keep_alive
// annotation, and zero locked arguments.
template <typename... Ts> struct func_extra_info {
using call_guard = void;
static constexpr bool keep_alive = false;
static constexpr size_t nargs_locked = 0;
};

template <typename T, typename... Ts> struct func_extra_info<T, Ts...>
Expand All @@ -354,6 +408,16 @@ struct func_extra_info<nanobind::keep_alive<Nurse, Patient>, Ts...> : func_extra
static constexpr bool keep_alive = true;
};

template <typename... Ts>
struct func_extra_info<nanobind::arg_locked, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

// An nb::lock_self annotation adds one to the count of locked arguments
// accumulated over the remaining annotations.
template <typename... Ts>
struct func_extra_info<nanobind::lock_self, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

template <typename T>
NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { }

Expand Down
3 changes: 3 additions & 0 deletions include/nanobind/nb_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ args_proxy api<Derived>::operator*() const {
template <typename T>
NB_INLINE void call_analyze(size_t &nargs, size_t &nkwargs, const T &value) {
using D = std::decay_t<T>;
static_assert(!std::is_base_of_v<arg_locked, D>,
"nb::arg().lock() may be used only when defining functions, "
"not when calling them");

if constexpr (std::is_same_v<D, arg_v>)
nkwargs++;
Expand Down
3 changes: 3 additions & 0 deletions include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ tuple make_tuple(Args &&...args) {
// Converts 'value' to a Python object right away and returns a copy of this
// annotation with that object attached as the default.
template <typename T> arg_v arg::operator=(T &&value) const {
    auto converted = cast((detail::forward_t<T>) value);
    return arg_v(*this, std::move(converted));
}
// Converts 'value' to a Python object right away and returns a copy of this
// locked annotation with that object attached as the default.
template <typename T> arg_locked_v arg_locked::operator=(T &&value) const {
    auto converted = cast((detail::forward_t<T>) value);
    return arg_locked_v(*this, std::move(converted));
}

template <typename Impl> template <typename T>
detail::accessor<Impl>& detail::accessor<Impl>::operator=(T &&value) {
Expand Down
58 changes: 56 additions & 2 deletions include/nanobind/nb_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,40 @@ bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags,
template <size_t I, typename... Ts, size_t... Is>
constexpr size_t count_args_before_index(std::index_sequence<Is...>) {
static_assert(sizeof...(Is) == sizeof...(Ts));
return ((Is < I && (std::is_same_v<arg, Ts> || std::is_same_v<arg_v, Ts>)) + ... + 0);
return ((Is < I && std::is_base_of_v<arg, Ts>) + ... + 0);
}

#if defined(NB_FREE_THREADED)
// Walks the annotation list in order and remembers which entries of 'args'
// are to be locked: h1 receives the first locked argument, h2 the most
// recent one (both are the same object when only one argument is locked).
struct ft_args_collector {
    PyObject **args;
    handle h1;
    handle h2;
    size_t index = 0;

    NB_INLINE explicit ft_args_collector(PyObject **a) : args(a) {}

    // Locked argument: record it, then move on to the next position
    NB_INLINE void apply(arg_locked *) {
        PyObject *current = args[index++];
        if (!h1.ptr())
            h1 = current;
        h2 = current;
    }

    // Unlocked argument: only advances the position
    NB_INLINE void apply(arg *) { index++; }

    // Any other annotation: does not correspond to an argument position
    NB_INLINE void apply(...) {}
};

// Holds a two-object critical section. lock() begins it on the objects
// gathered by the collector. The destructor ends it unconditionally, so
// lock() must have been called before the guard goes out of scope.
struct ft_args_guard {
    PyCriticalSection2 cs;

    NB_INLINE void lock(const ft_args_collector& info) {
        PyObject *first = info.h1.ptr();
        PyObject *second = info.h2.ptr();
        PyCriticalSection2_Begin(&cs, first, second);
    }

    ~ft_args_guard() { PyCriticalSection2_End(&cs); }
};
#endif

// Empty placeholder type, used in place of ft_args_guard when no function
// arguments are locked
struct no_guard {};

template <bool ReturnRef, bool CheckGuard, typename Func, typename Return,
typename... Args, size_t... Is, typename... Extra>
NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
Expand Down Expand Up @@ -66,13 +97,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),

// Determine the number of nb::arg/nb::arg_v annotations
constexpr size_t nargs_provided =
((std::is_same_v<arg, Extra> + std::is_same_v<arg_v, Extra>) + ... + 0);
(std::is_base_of_v<arg, Extra> + ... + 0);
constexpr bool is_method_det =
(std::is_same_v<is_method, Extra> + ... + 0) != 0;
constexpr bool is_getter_det =
(std::is_same_v<is_getter, Extra> + ... + 0) != 0;
constexpr bool has_arg_annotations = nargs_provided > 0 && !is_getter_det;

// Determine the number of potentially-locked function arguments
constexpr bool lock_self_det =
(std::is_same_v<lock_self, Extra> + ... + 0) != 0;
static_assert(Info::nargs_locked <= 2,
"At most two function arguments can be locked");
static_assert(!(lock_self_det && !is_method_det),
"The nb::lock_self() annotation only applies to methods");

// Detect location of nb::kw_only annotation, if supplied. As with args/kwargs
// we find the first and last location and later verify they match each other.
// Note this is an index in Extra... while args/kwargs_pos_* are indices in
Expand Down Expand Up @@ -187,6 +226,21 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
tuple<make_caster<Args>...> in;
(void) in;

#if defined(NB_FREE_THREADED)
std::conditional_t<Info::nargs_locked != 0, ft_args_guard, no_guard> guard;
if constexpr (Info::nargs_locked) {
ft_args_collector collector{args};
if constexpr (is_method_det) {
if constexpr (lock_self_det)
collector.apply((arg_locked *) nullptr);
else
collector.apply((arg *) nullptr);
}
(collector.apply((Extra *) nullptr), ...);
guard.lock(collector);
}
#endif

if constexpr (Info::keep_alive) {
if ((!from_python_keep_alive(in.template get<Is>(), args,
args_flags, cleanup, Is) || ...))
Expand Down
Loading

0 comments on commit 6d60ed7

Please sign in to comment.