Skip to content

Commit

Permalink
Use METH_COEXIST for protocols & Make some methods take optional args
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Mar 29, 2020
1 parent a76bd7c commit 45e163d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 30 deletions.
11 changes: 6 additions & 5 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,12 @@ pub const NUM: Proto = Proto {
pyres: false,
proto: "pyo3::class::number::PyNumberIOrProtocol",
},
MethodProto::Binary {
name: "__round__",
arg: "Ndigits",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
MethodProto::Unary {
name: "__neg__",
pyres: true,
Expand Down Expand Up @@ -666,11 +672,6 @@ pub const NUM: Proto = Proto {
pyres: true,
proto: "pyo3::class::number::PyNumberFloatProtocol",
},
MethodProto::Unary {
name: "__round__",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
MethodProto::Unary {
name: "__index__",
pyres: true,
Expand Down
4 changes: 2 additions & 2 deletions pyo3-derive-backend/src/pyproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ fn impl_proto_impl(
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);

py_methods.push(quote! {
impl #proto for #ty
{
Expand All @@ -86,7 +85,8 @@ fn impl_proto_impl(
Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
// We need METH_COEXIST here to prevent __add__ from overriding __radd__
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | pyo3::ffi::METH_COEXIST,
ml_doc: ""
})
}
Expand Down
13 changes: 7 additions & 6 deletions src/class/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Self::Modulo) -> Self::Result
fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberPowProtocol<'p>,
{
Expand Down Expand Up @@ -145,7 +145,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __rpow__(&'p self, other: Self::Other, module: Self::Modulo) -> Self::Result
fn __rpow__(&'p self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberRPowProtocol<'p>,
{
Expand Down Expand Up @@ -224,7 +224,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __ipow__(&'p mut self, other: Self::Other, modulo: Self::Modulo) -> Self::Result
fn __ipow__(&'p mut self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberIPowProtocol<'p>,
{
Expand Down Expand Up @@ -304,7 +304,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __round__(&'p self) -> Self::Result
fn __round__(&'p self, ndigits: Option<Self::Ndigits>) -> Self::Result
where
Self: PyNumberRoundProtocol<'p>,
{
Expand Down Expand Up @@ -610,6 +610,7 @@ pub trait PyNumberFloatProtocol<'p>: PyNumberProtocol<'p> {

pub trait PyNumberRoundProtocol<'p>: PyNumberProtocol<'p> {
type Success: IntoPy<PyObject>;
type Ndigits: FromPyObject<'p>;
type Result: Into<PyResult<Self::Success>>;
}

Expand Down Expand Up @@ -2137,7 +2138,7 @@ where
}
}

trait PyNumberComplexProtocolImpl {
pub trait PyNumberComplexProtocolImpl {
fn __complex__() -> Option<PyMethodDef>;
}

Expand All @@ -2150,7 +2151,7 @@ where
}
}

trait PyNumberRoundProtocolImpl {
pub trait PyNumberRoundProtocolImpl {
fn __round__() -> Option<PyMethodDef>;
}

Expand Down
63 changes: 46 additions & 17 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,39 @@ use pyo3::py_run;
mod common;

#[pyclass]
struct UnaryArithmetic {}
struct UnaryArithmetic {
inner: f64,
}

impl UnaryArithmetic {
fn new(value: f64) -> Self {
UnaryArithmetic { inner: value }
}
}

#[pyproto]
impl PyObjectProtocol for UnaryArithmetic {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("UA({})", self.inner))
}
}

#[pyproto]
impl PyNumberProtocol for UnaryArithmetic {
fn __neg__(&self) -> PyResult<&'static str> {
Ok("neg")
fn __neg__(&self) -> PyResult<Self> {
Ok(Self::new(-self.inner))
}

fn __pos__(&self) -> PyResult<&'static str> {
Ok("pos")
fn __pos__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner))
}

fn __abs__(&self) -> PyResult<&'static str> {
Ok("abs")
fn __abs__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner.abs()))
}

fn __invert__(&self) -> PyResult<&'static str> {
Ok("invert")
fn __round__(&self, _ndigits: Option<u32>) -> PyResult<Self> {
Ok(Self::new(self.inner.round()))
}
}

Expand All @@ -34,11 +49,11 @@ fn unary_arithmetic() {
let gil = Python::acquire_gil();
let py = gil.python();

let c = PyCell::new(py, UnaryArithmetic {}).unwrap();
py_run!(py, c, "assert -c == 'neg'");
py_run!(py, c, "assert +c == 'pos'");
py_run!(py, c, "assert abs(c) == 'abs'");
py_run!(py, c, "assert ~c == 'invert'");
let c = PyCell::new(py, UnaryArithmetic::new(2.718281)).unwrap();
py_run!(py, c, "assert repr(-c) == 'UA(-2.718281)'");
py_run!(py, c, "assert repr(+c) == 'UA(2.718281)'");
py_run!(py, c, "assert repr(abs(c)) == 'UA(2.718281)'");
py_run!(py, c, "assert repr(round(c)) == 'UA(3)'");
}

#[pyclass]
Expand Down Expand Up @@ -104,6 +119,11 @@ impl PyNumberProtocol for InPlaceOperations {
self.value |= other;
Ok(())
}

fn __ipow__(&mut self, other: u32, _mod: Option<u32>) -> PyResult<()> {
self.value = self.value.pow(other);
Ok(())
}
}

#[test]
Expand All @@ -124,6 +144,7 @@ fn inplace_operations() {
init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'");
init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'");
init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'");
init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'");
}

#[pyproto]
Expand Down Expand Up @@ -159,6 +180,10 @@ impl PyNumberProtocol for BinaryArithmetic {
fn __or__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | {:?}", lhs, rhs))
}

fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?}", lhs, rhs))
}
}

#[test]
Expand Down Expand Up @@ -186,6 +211,8 @@ fn binary_arithmetic() {
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert c | 1 == 'BA | 1'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'");
py_run!(py, c, "assert c ** 1 == 'BA ** 1'");
}

#[pyclass]
Expand Down Expand Up @@ -225,7 +252,7 @@ impl PyNumberProtocol for RhsArithmetic {
Ok(format!("{:?} | RA", other))
}

fn __rpow__(&self, other: &PyAny, _module: &PyAny) -> PyResult<String> {
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
}
}
Expand Down Expand Up @@ -291,9 +318,11 @@ fn lhs_override_rhs() {
let py = gil.python();

let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap();
py_run!(py, c, "assert c.__radd__(1) == '1 + BA'");
// Not overrided
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
// Overrided
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
}

Expand Down

0 comments on commit 45e163d

Please sign in to comment.