Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Implement Multi-Phase Module Initialization as per PEP 489 #4162

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,7 @@ pub fn pymodule_module_impl(
unsafe {
impl_::ModuleDef::new(
__PYO3_NAME,
#doc,
INITIALIZER
#doc
)
}
}};
Expand Down Expand Up @@ -385,6 +384,9 @@ pub fn pymodule_function_impl(

let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false);

// Each generated `module_objects_init` function is exported as a separate symbol.
let module_objects_init_symbol = format!("__module_objects_init__{}", ident.unraw());

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
if function.sig.inputs.len() == 2 {
Expand All @@ -396,6 +398,32 @@ pub fn pymodule_function_impl(
Ok(quote! {
#[doc(hidden)]
#vis mod #ident {
/// Function used to add classes, functions, etc. to the module during
/// multi-phase initialization.
///
/// Returns `0` on success. On failure a Python exception is set and `-1`
/// is returned, so that no panic unwinds across the `extern "C"` boundary.
#[doc(hidden)]
#[export_name = #module_objects_init_symbol]
pub unsafe extern "C" fn __module_objects_init(module: *mut #pyo3_path::ffi::PyObject) -> ::std::ffi::c_int {
    // SAFETY: assumes the interpreter invokes this callback with the GIL
    // held and, when non-NULL, passes a pointer to a live module object.
    unsafe {
        #pyo3_path::Python::with_gil_unchecked(|py| {
            if module.is_null() {
                #pyo3_path::exceptions::PySystemError::new_err("'module' shouldn't be NULL")
                    .restore(py);
                return -1;
            }

            // `module` is only lent to us by the caller, so take a new strong
            // reference instead of adopting the caller's; the reference is
            // released again when `module` is dropped at the end of the closure.
            let module = #pyo3_path::Py::<#pyo3_path::types::PyModule>::from_borrowed_ptr(py, module);

            match MakeDef::do_init_multiphase(module.bind(py)) {
                ::std::result::Result::Ok(()) => 0,
                ::std::result::Result::Err(err) => {
                    // Hand the error to the interpreter rather than panicking.
                    err.restore(py);
                    -1
                }
            }
        })
    }
}

// Type-erased pointer to `__module_objects_init`; this is the value passed
// for the `Py_mod_exec` entry of `_PYO3_SLOTS`.
#[doc(hidden)]
pub const __PYO3_INIT: *mut ::std::ffi::c_void = __module_objects_init as *mut ::std::ffi::c_void;

#initialization
}

Expand All @@ -405,17 +433,22 @@ pub fn pymodule_function_impl(
// inside a function body)
#[allow(unknown_lints, non_local_definitions)]
impl #ident::MakeDef {
/// Helper function for `__module_objects_init`: runs the user's module
/// function on `module` and returns its result unchanged.
///
/// Should probably be put somewhere else.
#[doc(hidden)]
pub fn do_init_multiphase(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
// `module_args` is the module, optionally preceded by a `Python<'_>` marker,
// depending on how many parameters the user's function declares.
#ident(#(#module_args),*)
}

// Builds the `ModuleDef` for this module. It is a `const fn` because its
// result is used as the initializer expression of the `_PYO3_DEF` static.
const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
// Adapter with a fixed signature that forwards to the user's module function.
fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
#ident(#(#module_args),*)
}

const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
unsafe {
#pyo3_path::impl_::pymodule::ModuleDef::new(
#ident::__PYO3_NAME,
#doc,
INITIALIZER
)
}
}
Expand All @@ -442,14 +475,31 @@ fn module_initialization(
#[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
};

if !is_submodule {
result.extend(quote! {
// Slot table for multi-phase initialization; `__pyo3_init` wraps it in
// `ModuleDefSlots` and installs it on `_PYO3_DEF` via `set_multiphase_items`.
#[doc(hidden)]
pub static _PYO3_SLOTS: &[#pyo3_path::impl_::pymodule_state::ModuleDefSlot] = &[
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::start(),
// Exec slot: `__PYO3_INIT` is the callback that populates the module.
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::new(
#pyo3_path::ffi::Py_mod_exec,
__PYO3_INIT,
),
// NOTE(review): this `cfg` sits in macro output, so it is evaluated in the
// crate that expands the macro, not in pyo3 itself - confirm `Py_3_12` is
// actually set there, otherwise this slot is silently omitted.
#[cfg(Py_3_12)]
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::per_interpreter_gil(),
#pyo3_path::impl_::pymodule_state::ModuleDefSlot::end(),
];

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
#pyo3_path::impl_::trampoline::module_init(|py| {
// Install the slot table on the definition before handing it to the interpreter.
let slots = #pyo3_path::impl_::pymodule_state::ModuleDefSlots::new_from_static(_PYO3_SLOTS);
_PYO3_DEF.set_multiphase_items(slots);
// NOTE(review): the trailing `;` makes this closure evaluate to `()`, so the
// `PyResult` from `make_module` (including any error) is discarded - confirm
// whether the result is meant to be the closure's return value.
_PYO3_DEF.make_module(py);
})
}
});
}
Expand Down
1 change: 1 addition & 0 deletions src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod pyclass;
pub mod pyfunction;
pub mod pymethods;
pub mod pymodule;
pub mod pymodule_state;
#[doc(hidden)]
pub mod trampoline;
pub mod wrap;
167 changes: 102 additions & 65 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,33 @@ use crate::{
Bound, Py, PyClass, PyMethodDef, PyResult, PyTypeInfo, Python,
};

use crate::impl_::pymodule_state as state;

// TODO: replace other usages (if this passes review :^) )
pub use state::ModuleDefSlot;

/// `Sync` wrapper of `ffi::PyModuleDef`.
pub struct ModuleDef {
// Wrapped in `UnsafeCell` so that the Rust compiler treats this as interior
// mutability: the definition is written through a shared `&'static self`.
ffi_def: UnsafeCell<ffi::PyModuleDef>,
// User-supplied function that populates a newly created module object.
initializer: ModuleInitializer,
/// Interpreter ID where module was initialized (not applicable on PyPy
/// or GraalPy; see the `cfg` below for the exact conditions).
#[cfg(all(
not(any(PyPy, GraalPy)),
Py_3_9,
not(all(windows, Py_LIMITED_API, not(Py_3_10)))
))]
interpreter: AtomicI64,
// TODO: `module` can probably be removed - it is marked `#[allow(unused)]`
// below; confirm nothing else still reads it.
/// Initialized module object, cached to avoid reinitialization.
#[allow(unused)]
module: GILOnceCell<Py<PyModule>>,
}

/// Wrapper to enable initializer to be used in const fns.
pub struct ModuleInitializer(pub for<'py> fn(&Bound<'py, PyModule>) -> PyResult<()>);

unsafe impl Sync for ModuleDef {}

impl ModuleDef {
/// Make new module definition with given module name.
pub const unsafe fn new(
name: &'static CStr,
doc: &'static CStr,
initializer: ModuleInitializer,
) -> Self {
pub const unsafe fn new(name: &'static CStr, doc: &'static CStr) -> Self {
const INIT: ffi::PyModuleDef = ffi::PyModuleDef {
m_base: ffi::PyModuleDef_HEAD_INIT,
m_name: std::ptr::null(),
Expand All @@ -74,7 +73,6 @@ impl ModuleDef {

ModuleDef {
ffi_def,
initializer,
// -1 is never expected to be a valid interpreter ID
#[cfg(all(
not(any(PyPy, GraalPy)),
Expand All @@ -85,8 +83,9 @@ impl ModuleDef {
module: GILOnceCell::new(),
}
}

/// Builds a module using user given initializer. Used for [`#[pymodule]`][crate::pymodule].
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<Py<PyModule>> {
pub fn make_module(&'static self, py: Python<'_>) -> PyResult<*mut ffi::PyModuleDef> {
#[cfg(all(PyPy, not(Py_3_8)))]
{
use crate::types::any::PyAnyMethods;
Expand Down Expand Up @@ -140,18 +139,31 @@ impl ModuleDef {
}
}
}
self.module
.get_or_try_init(py, || {
let module = unsafe {
Py::<PyModule>::from_owned_ptr_or_err(
py,
ffi::PyModule_Create(self.ffi_def.get()),
)?
};
self.initializer.0(module.bind(py))?;
Ok(module)
})
.map(|py_module| py_module.clone_ref(py))

if (unsafe { *self.ffi_def.get() }).m_slots.is_null() {
return Err(PyImportError::new_err(
"'m_slots' of module definition is NULL",
));
}

let module_def_ptr = unsafe { ffi::PyModuleDef_Init(self.ffi_def.get()) };

if module_def_ptr.is_null() {
return Err(PyImportError::new_err("PyModuleDef_Init returned NULL"));
}

Ok(module_def_ptr.cast())
}

/// Configures the wrapped `ffi::PyModuleDef` for multi-phase initialization.
///
/// Sets `m_size` to the size of `state::ModuleState`, stores the slot table
/// taken from `slots` in `m_slots`, and registers the `state::module_state_*`
/// callbacks as `m_traverse`, `m_clear` and `m_free`.
pub fn set_multiphase_items(&'static self, slots: state::ModuleDefSlots) {
    let def = self.ffi_def.get();
    let state_size = std::mem::size_of::<state::ModuleState>() as ffi::Py_ssize_t;

    // SAFETY (assumed - confirm): `def` points into `self`'s own `UnsafeCell`,
    // so it is valid; callers must ensure nothing else accesses the
    // definition while these fields are written.
    unsafe {
        (*def).m_size = state_size;
        (*def).m_slots = slots.into_inner();
        (*def).m_free = Some(state::module_state_free);
        (*def).m_clear = Some(state::module_state_clear);
        (*def).m_traverse = Some(state::module_state_traverse);
    }
}
}

Expand Down Expand Up @@ -204,7 +216,44 @@ impl PyAddToModule for PyMethodDef {
/// For adding a module to a module.
impl PyAddToModule for ModuleDef {
    /// Prepares this definition for multi-phase initialization with a
    /// `Py_mod_exec` slot that attaches the submodule to `module`.
    ///
    /// The submodule is stored on the parent under the submodule's own name
    /// (the definition's `m_name`).
    fn add_to_module(&'static self, module: &Bound<'_, PyModule>) -> PyResult<()> {
        let parent_ptr = module.as_ptr();

        // SAFETY: `self` is `'static` and the field is only read here.
        let child_name = unsafe { (*self.ffi_def.get()).m_name };

        // `move` so the closure owns copies of the two raw pointers instead of
        // borrowing locals of this function, which would dangle once it returns.
        let add_to_parent = move |child_ptr: *mut ffi::PyObject| -> std::ffi::c_int {
            // SAFETY (assumed - confirm): invoked by the interpreter with the
            // GIL held and a valid, borrowed `child_ptr`; `parent_ptr` must
            // still be alive (see the note on `alloc_closure` below).
            unsafe {
                // `PyModule_AddObject` steals a reference to the value on
                // success only. The child is merely lent to us, so take our
                // own reference first and release it again if the call fails.
                ffi::Py_INCREF(child_ptr);
                if ffi::PyModule_AddObject(parent_ptr, child_name, child_ptr) < 0 {
                    ffi::Py_DECREF(child_ptr);
                    return -1;
                }
            }

            0
        };

        // SAFETY: We only use this closure inside the ModuleDef's slots and
        // then immediately initialize the module - this closure /
        // "function pointer" isn't used anywhere else afterwards and can't
        // outlive the current thread.
        let add_to_parent = unsafe { state::alloc_closure(add_to_parent) };

        let slots = [
            state::ModuleDefSlot::start(),
            state::ModuleDefSlot::new(ffi::Py_mod_exec, add_to_parent),
            #[cfg(Py_3_12)]
            state::ModuleDefSlot::per_interpreter_gil(),
            state::ModuleDefSlot::end(),
        ];

        let slots = state::alloc_slots(slots);
        self.set_multiphase_items(slots);

        // NOTE(review): `make_module` only returns the `PyModuleDef` pointer;
        // nothing visible here creates the submodule object or runs its
        // slots - confirm what eventually triggers `add_to_parent`.
        let _module_def_ptr = self.make_module(module.py())?;

        Ok(())
    }
}

Expand All @@ -218,50 +267,31 @@ mod tests {

use crate::{
ffi,
impl_::pymodule_state as state,
types::{any::PyAnyMethods, module::PyModuleMethods, PyModule},
Bound, PyResult, Python,
};

use super::{ModuleDef, ModuleInitializer};
use super::ModuleDef;

#[test]
fn module_init() {
static MODULE_DEF: ModuleDef = unsafe {
ModuleDef::new(
ffi::c_str!("test_module"),
ffi::c_str!("some doc"),
ModuleInitializer(|m| {
m.add("SOME_CONSTANT", 42)?;
Ok(())
}),
)
};
static MODULE_DEF: ModuleDef =
unsafe { ModuleDef::new(ffi::c_str!("test_module"), ffi::c_str!("some doc")) };

let slots = [
state::ModuleDefSlot::start(),
#[cfg(Py_3_12)]
state::ModuleDefSlot::per_interpreter_gil(),
state::ModuleDefSlot::end(),
];

MODULE_DEF.set_multiphase_items(state::alloc_slots(slots));

Python::with_gil(|py| {
let module = MODULE_DEF.make_module(py).unwrap().into_bound(py);
assert_eq!(
module
.getattr("__name__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"test_module",
);
assert_eq!(
module
.getattr("__doc__")
.unwrap()
.extract::<Cow<'_, str>>()
.unwrap(),
"some doc",
);
assert_eq!(
module
.getattr("SOME_CONSTANT")
.unwrap()
.extract::<u8>()
.unwrap(),
42,
);
let module_def = MODULE_DEF.make_module(py).unwrap();
// FIXME: get PyModule from PyModuleDef ..?
unimplemented!("Test currently not implemented");
})
}

Expand All @@ -272,6 +302,13 @@ mod tests {
static NAME: &CStr = ffi::c_str!("test_module");
static DOC: &CStr = ffi::c_str!("some doc");

let slots = [
state::ModuleDefSlot::start(),
#[cfg(Py_3_12)]
state::ModuleDefSlot::per_interpreter_gil(),
state::ModuleDefSlot::end(),
];

static INIT_CALLED: AtomicBool = AtomicBool::new(false);

#[allow(clippy::unnecessary_wraps)]
Expand All @@ -281,12 +318,12 @@ mod tests {
}

unsafe {
let module_def: ModuleDef = ModuleDef::new(NAME, DOC, ModuleInitializer(init));
assert_eq!((*module_def.ffi_def.get()).m_name, NAME.as_ptr() as _);
assert_eq!((*module_def.ffi_def.get()).m_doc, DOC.as_ptr() as _);
static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new(NAME, DOC) };
MODULE_DEF.set_multiphase_items(state::alloc_slots(slots));
assert_eq!((*MODULE_DEF.ffi_def.get()).m_name, NAME.as_ptr() as _);
assert_eq!((*MODULE_DEF.ffi_def.get()).m_doc, DOC.as_ptr() as _);

Python::with_gil(|py| {
module_def.initializer.0(&py.import_bound("builtins").unwrap()).unwrap();
Python::with_gil(|_py| {
assert!(INIT_CALLED.load(Ordering::SeqCst));
})
}
Expand Down
Loading
Loading