Skip to content

Commit

Permalink
add PyAnyMethods for binary operators
Browse files Browse the repository at this point in the history
also pow

fixes PyO3#3709
  • Loading branch information
alex committed Dec 29, 2023
1 parent 6776b90 commit 30f6cd0
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
84 changes: 84 additions & 0 deletions src/types/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,38 @@ pub trait PyAnyMethods<'py> {
where
O: ToPyObject;

/// Computes `self + other`.
fn add<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self - other`.
fn sub<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self * other`.
fn mul<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self / other`.
fn div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject;

/// Computes `self & other`.
fn bitand<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;

/// Determines whether this object appears callable.
///
/// This is equivalent to Python's [`callable()`][1] function.
Expand Down Expand Up @@ -1680,6 +1712,26 @@ pub trait PyAnyMethods<'py> {
fn py_super(&self) -> PyResult<Bound<'py, PySuper>>;
}

macro_rules! implement_binop {
($name:ident, $c_api:ident, $op:expr) => {
#[doc = concat!("Computes `self ", $op, " other`.")]
fn $name<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) }
}

let py = self.py();
inner(self, other.to_object(py).into_bound(py))
}
};
}

impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
#[inline]
fn is<T: AsPyPointer>(&self, other: &T) -> bool {
Expand Down Expand Up @@ -1855,6 +1907,38 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
.and_then(|any| any.is_truthy())
}

implement_binop!(add, PyNumber_Add, "+");
implement_binop!(sub, PyNumber_Subtract, "-");
implement_binop!(mul, PyNumber_Multiply, "*");
implement_binop!(div, PyNumber_TrueDivide, "/");
implement_binop!(bitand, PyNumber_And, "&");

/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
modulus: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe {
ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr())
.assume_owned_or_err(any.py())
}
}

let py = self.py();
inner(
self,
other.to_object(py).into_bound(py),
modulus.to_object(py).into_bound(py),
)
}

fn is_callable(&self) -> bool {
unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 }
}
Expand Down
7 changes: 7 additions & 0 deletions tests/test_arithmetics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ fn binary_arithmetic() {
py_expect_exception!(py, c, "1 ** c", PyTypeError);

py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");

let c: Bound<'_, PyAny> = c.extract().unwrap();
assert_eq!(c.add(&c).unwrap().extract::<&str>().unwrap(), "BA + BA");
assert_eq!(
c.pow(&c, py.None()).unwrap().extract::<&str>().unwrap(),
"BA ** BA (mod: None)"
);
});
}

Expand Down

0 comments on commit 30f6cd0

Please sign in to comment.