Skip to content

Commit

Permalink
Merge #3157
Browse files Browse the repository at this point in the history
3157: Add support for `#[new]` which is also a `#[classmethod]` r=davidhewitt a=stuhood

Fixes #3077.

Co-authored-by: Stu Hood <[email protected]>
  • Loading branch information
bors[bot] and stuhood authored May 17, 2023
2 parents 0f00240 + 20c5618 commit 3b4c7d3
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 10 deletions.
21 changes: 21 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,27 @@ Declares a class method callable from Python.
* For details on `parameter-list`, see the documentation of `Method arguments` section.
* The return type must be `PyResult<T>` or `T` for some `T` that implements `IntoPy<PyObject>`.

### Constructors which accept a class argument

To create a constructor which takes a positional class argument, you can combine the `#[classmethod]` and `#[new]` modifiers:
```rust
# use pyo3::prelude::*;
# use pyo3::types::PyType;
# #[pyclass]
# struct BaseClass(PyObject);
#
#[pymethods]
impl BaseClass {
#[new]
#[classmethod]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
}
}
```

## Static methods

To create a static method for a custom class, the method needs to be annotated with the
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3157.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument.
38 changes: 30 additions & 8 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
pub enum MethodTypeAttribute {
/// `#[new]`
New,
/// `#[new]` && `#[classmethod]`
NewClassMethod,
/// `#[classmethod]`
ClassMethod,
/// `#[classattr]`
Expand All @@ -102,6 +104,7 @@ pub enum FnType {
Setter(SelfType),
Fn(SelfType),
FnNew,
FnNewClass,
FnClass,
FnStatic,
FnModule,
Expand All @@ -122,7 +125,7 @@ impl FnType {
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
}
FnType::FnClass => {
FnType::FnClass | FnType::FnNewClass => {
quote! {
let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject);
}
Expand Down Expand Up @@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
Some(MethodTypeAttribute::New) => {
Some(MethodTypeAttribute::New) | Some(MethodTypeAttribute::NewClassMethod) => {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
(FnType::FnNew, false, Some(CallingConvention::TpNew))
if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) {
(FnType::FnNew, false, Some(CallingConvention::TpNew))
} else {
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
}
}
Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None),
Some(MethodTypeAttribute::Getter) => {
Expand Down Expand Up @@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?;
let call = quote! { #rust_name(#(#args),*) };
let call = match &self.tp {
FnType::FnNew => quote! { #rust_name(#(#args),*) },
FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) },
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
};
quote! {
unsafe fn #ident(
#py: _pyo3::Python<'_>,
Expand Down Expand Up @@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
FnType::Fn(_) => Some("self"),
FnType::FnModule => Some("module"),
FnType::FnClass => Some("cls"),
FnType::FnClass | FnType::FnNewClass => Some("cls"),
FnType::FnStatic | FnType::FnNew => None,
};

Expand Down Expand Up @@ -637,11 +648,22 @@ fn parse_method_attributes(
let mut deprecated_args = None;
let mut ty: Option<MethodTypeAttribute> = None;

macro_rules! set_compound_ty {
($new_ty:expr, $ident:expr) => {
ty = match (ty, $new_ty) {
(None, new_ty) => Some(new_ty),
(Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod),
(Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod),
(Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"),
};
};
}

macro_rules! set_ty {
($new_ty:expr, $ident:expr) => {
ensure_spanned!(
ty.replace($new_ty).is_none(),
$ident.span() => "cannot specify a second method type"
$ident.span() => "cannot combine these method types"
);
};
}
Expand All @@ -650,13 +672,13 @@ fn parse_method_attributes(
match attr.parse_meta() {
Ok(syn::Meta::Path(name)) => {
if name.is_ident("new") || name.is_ident("__new__") {
set_ty!(MethodTypeAttribute::New, name);
set_compound_ty!(MethodTypeAttribute::New, name);
} else if name.is_ident("init") || name.is_ident("__init__") {
bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0");
} else if name.is_ident("call") || name.is_ident("__call__") {
bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0");
} else if name.is_ident("classmethod") {
set_ty!(MethodTypeAttribute::ClassMethod, name);
set_compound_ty!(MethodTypeAttribute::ClassMethod, name);
} else if name.is_ident("staticmethod") {
set_ty!(MethodTypeAttribute::StaticMethod, name);
} else if name.is_ident("classattr") {
Expand Down
4 changes: 3 additions & 1 deletion pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ pub fn gen_py_method(
Some(quote!(_pyo3::ffi::METH_STATIC)),
)?),
// special prototypes
(_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?),
(_, FnType::FnNew) | (_, FnType::FnNewClass) => {
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
}

(_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
cls,
Expand Down
23 changes: 23 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use pyo3::exceptions::PyValueError;
use pyo3::iter::IterNextOutput;
use pyo3::prelude::*;
use pyo3::types::PyType;

#[pyclass]
struct EmptyClass {}
Expand Down Expand Up @@ -35,9 +37,30 @@ impl PyClassIter {
}
}

/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
#[pyclass(subclass)]
#[derive(Clone, Debug)]
struct AssertingBaseClass;

#[pymethods]
impl AssertingBaseClass {
#[new]
#[classmethod]
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
if !cls.is(expected_type) {
return Err(PyValueError::new_err(format!(
"{:?} != {:?}",
cls, expected_type
)));
}
Ok(Self)
}
}

#[pymodule]
pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
Ok(())
}
11 changes: 11 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ def test_iter():
with pytest.raises(StopIteration) as excinfo:
next(i)
assert excinfo.value.value == "Ended"


class AssertingSubClass(pyclasses.AssertingBaseClass):
pass


def test_new_classmethod():
# The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)
2 changes: 1 addition & 1 deletion tests/ui/invalid_pymethods.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ error: `signature` not allowed with `classattr`
105 | #[pyo3(signature = ())]
| ^^^^^^^^^

error: cannot specify a second method type
error: cannot combine these method types
--> tests/ui/invalid_pymethods.rs:112:7
|
112 | #[staticmethod]
Expand Down

0 comments on commit 3b4c7d3

Please sign in to comment.