From 45e163d98e994812cc73cba773320ddcae541176 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sun, 29 Mar 2020 23:47:02 +0900 Subject: [PATCH] Use METH_COEXIST for protocols & Make some methods take optional args --- pyo3-derive-backend/src/defs.rs | 11 +++--- pyo3-derive-backend/src/pyproto.rs | 4 +- src/class/number.rs | 13 +++--- tests/test_arithmetics.rs | 63 ++++++++++++++++++++++-------- 4 files changed, 61 insertions(+), 30 deletions(-) diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 4d07e52fc65..17401c4d95b 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -631,6 +631,12 @@ pub const NUM: Proto = Proto { pyres: false, proto: "pyo3::class::number::PyNumberIOrProtocol", }, + MethodProto::Binary { + name: "__round__", + arg: "Ndigits", + pyres: true, + proto: "pyo3::class::number::PyNumberRoundProtocol", + }, MethodProto::Unary { name: "__neg__", pyres: true, @@ -666,11 +672,6 @@ pub const NUM: Proto = Proto { pyres: true, proto: "pyo3::class::number::PyNumberFloatProtocol", }, - MethodProto::Unary { - name: "__round__", - pyres: true, - proto: "pyo3::class::number::PyNumberRoundProtocol", - }, MethodProto::Unary { name: "__index__", pyres: true, diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index 8416dfb2e97..c64c290ccd3 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -75,7 +75,6 @@ fn impl_proto_impl( Err(err) => return err.to_compile_error(), }; let meth = pymethod::impl_proto_wrap(ty, &fn_spec); - py_methods.push(quote! { impl #proto for #ty { @@ -86,7 +85,8 @@ fn impl_proto_impl( Some(pyo3::class::PyMethodDef { ml_name: stringify!(#name), ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, + // We need METH_COEXIST here to prevent __add__ from overriding __radd__ + ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | pyo3::ffi::METH_COEXIST, ml_doc: "" }) } diff --git a/src/class/number.rs b/src/class/number.rs index ed9ef3fe0ba..9d803092c79 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -60,7 +60,7 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Self::Modulo) -> Self::Result + fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Option) -> Self::Result where Self: PyNumberPowProtocol<'p>, { @@ -145,7 +145,7 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __rpow__(&'p self, other: Self::Other, module: Self::Modulo) -> Self::Result + fn __rpow__(&'p self, other: Self::Other, modulo: Option) -> Self::Result where Self: PyNumberRPowProtocol<'p>, { @@ -224,7 +224,7 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __ipow__(&'p mut self, other: Self::Other, modulo: Self::Modulo) -> Self::Result + fn __ipow__(&'p mut self, other: Self::Other, modulo: Option) -> Self::Result where Self: PyNumberIPowProtocol<'p>, { @@ -304,7 +304,7 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __round__(&'p self) -> Self::Result + fn __round__(&'p self, ndigits: Option) -> Self::Result where Self: PyNumberRoundProtocol<'p>, { @@ -610,6 +610,7 @@ pub trait PyNumberFloatProtocol<'p>: PyNumberProtocol<'p> { pub trait PyNumberRoundProtocol<'p>: PyNumberProtocol<'p> { type Success: IntoPy; + type Ndigits: FromPyObject<'p>; type Result: Into>; } @@ -2137,7 +2138,7 @@ where } } -trait PyNumberComplexProtocolImpl { +pub trait PyNumberComplexProtocolImpl { fn __complex__() -> Option; } @@ -2150,7 +2151,7 @@ where } } -trait PyNumberRoundProtocolImpl { +pub trait PyNumberRoundProtocolImpl { fn __round__() -> Option; } diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 39367406cc7..9fc2cecec82 100755 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -8,24 +8,39 @@ use pyo3::py_run; mod common; #[pyclass] -struct UnaryArithmetic {} +struct UnaryArithmetic { + inner: f64, +} + +impl UnaryArithmetic { + fn new(value: f64) -> Self { + UnaryArithmetic { inner: value } + } +} + +#[pyproto] +impl PyObjectProtocol for UnaryArithmetic { + fn __repr__(&self) -> PyResult { + Ok(format!("UA({})", self.inner)) + } +} #[pyproto] impl PyNumberProtocol for UnaryArithmetic { - fn __neg__(&self) -> PyResult<&'static str> { - Ok("neg") + fn __neg__(&self) -> PyResult { + Ok(Self::new(-self.inner)) } - fn __pos__(&self) -> PyResult<&'static str> { - Ok("pos") + fn __pos__(&self) -> PyResult { + Ok(Self::new(self.inner)) } - fn __abs__(&self) -> PyResult<&'static str> { - Ok("abs") + fn __abs__(&self) -> PyResult { + Ok(Self::new(self.inner.abs())) } - fn __invert__(&self) -> PyResult<&'static str> { - Ok("invert") + fn __round__(&self, _ndigits: Option) -> PyResult { + Ok(Self::new(self.inner.round())) } } @@ -34,11 +49,11 @@ fn unary_arithmetic() { let gil = Python::acquire_gil(); let py = gil.python(); - let c = PyCell::new(py, UnaryArithmetic {}).unwrap(); - py_run!(py, c, "assert -c == 'neg'"); - py_run!(py, c, "assert +c == 'pos'"); - py_run!(py, c, "assert abs(c) == 'abs'"); - py_run!(py, c, "assert ~c == 'invert'"); + let c = PyCell::new(py, UnaryArithmetic::new(2.718281)).unwrap(); + py_run!(py, c, "assert repr(-c) == 'UA(-2.718281)'"); + py_run!(py, c, "assert repr(+c) == 'UA(2.718281)'"); + py_run!(py, c, "assert repr(abs(c)) == 'UA(2.718281)'"); + py_run!(py, c, "assert repr(round(c)) == 'UA(3)'"); } #[pyclass] @@ -104,6 +119,11 @@ impl PyNumberProtocol for InPlaceOperations { self.value |= other; Ok(()) } + + fn __ipow__(&mut self, other: u32, _mod: Option) -> PyResult<()> { + self.value = self.value.pow(other); + Ok(()) + } } #[test] @@ -124,6 +144,7 @@ fn inplace_operations() { init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'"); init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'"); init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'"); } #[pyproto] @@ -159,6 +180,10 @@ impl PyNumberProtocol for BinaryArithmetic { fn __or__(lhs: &PyAny, rhs: &PyAny) -> PyResult { Ok(format!("{:?} | {:?}", lhs, rhs)) } + + fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option) -> PyResult { + Ok(format!("{:?} ** {:?}", lhs, rhs)) + } } #[test] @@ -186,6 +211,8 @@ fn binary_arithmetic() { py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); py_run!(py, c, "assert c | 1 == 'BA | 1'"); py_run!(py, c, "assert 1 | c == '1 | BA'"); + py_run!(py, c, "assert 1 ** c == '1 ** BA'"); + py_run!(py, c, "assert c ** 1 == 'BA ** 1'"); } #[pyclass] @@ -225,7 +252,7 @@ impl PyNumberProtocol for RhsArithmetic { Ok(format!("{:?} | RA", other)) } - fn __rpow__(&self, other: &PyAny, _module: &PyAny) -> PyResult { + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult { Ok(format!("{:?} ** RA", other)) } } @@ -291,9 +318,11 @@ fn lhs_override_rhs() { let py = gil.python(); let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap(); - py_run!(py, c, "assert c.__radd__(1) == '1 + BA'"); + // Not overrided + py_run!(py, c, "assert c.__radd__(1) == '1 + RA'"); + py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); + // Overrided py_run!(py, c, "assert 1 + c == '1 + BA'"); - py_run!(py, c, "assert c.__rsub__(1) == '1 - BA'"); py_run!(py, c, "assert 1 - c == '1 - BA'"); }