Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Numpy 2.x (take 2) #442

Merged
merged 22 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
39f0b6c
Add feature flags for numpy 1 and 2
aMarcireau Jun 3, 2024
3b6c0b7
Remove outdated 1.0 functions, add new 2.0 functions
aMarcireau Jun 3, 2024
c2e55f8
Use the new "universal" access functions
aMarcireau Jun 3, 2024
e060735
Fix runtime tests
aMarcireau Jun 4, 2024
31924fb
Remove feature flags and always check the version at runtime where ap…
aMarcireau Jun 7, 2024
7f59548
Avoid API changes by using self.py()
aMarcireau Jun 8, 2024
9d7e0a2
Fixup
aMarcireau Jun 8, 2024
3777299
Fix flags for Windows
aMarcireau Jul 8, 2024
c89932b
Avoid half-open range patterns to appease our MSRV build.
adamreichold Jul 21, 2024
92b0947
Remove residual numpy-1/2 feature usage.
adamreichold Jul 21, 2024
9de97ba
Use explicit patterns so exhaustiveness checking works, but avoid ope…
adamreichold Jul 21, 2024
be3ce9b
Add ci job to test numpy2 and allow numpy2 in examples
maffoo Aug 31, 2024
1a93b96
Make size and buf fields public in npy_static_string
maffoo Aug 31, 2024
2882c6a
Fixes from @adamreichold review
maffoo Aug 31, 2024
61fb9e0
Simplify impl_api macro
maffoo Aug 31, 2024
17eda7f
lint
maffoo Sep 3, 2024
abebd0f
Fixes from review
maffoo Sep 20, 2024
da0807f
Apply suggestions from code review
davidhewitt Sep 23, 2024
6013343
Fix dtypes tests to work with numpy 1 or 2
maffoo Sep 23, 2024
5dd5280
Rename to _PyArray_DescrNumPy2 and _PyArray_LegacyDescr to match c code
maffoo Oct 3, 2024
4322102
Fail compilation on 32-bit windows
maffoo Oct 3, 2024
d09a35d
Add a changelog entry
maffoo Oct 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ jobs:
{ os: "ubuntu-latest", python-architecture: "x64", rust-target: "x86_64-unknown-linux-gnu" },
{ os: "macOS-13", python-architecture: "x64", rust-target: "x86_64-apple-darwin" },
{ os: "windows-latest", python-architecture: "x64", rust-target: "x86_64-pc-windows-msvc" },
{ os: "windows-latest", python-architecture: "x86", rust-target: "i686-pc-windows-msvc" },
]
include:
# Older versions of CPython are not available for AArch64.
Expand Down Expand Up @@ -70,7 +69,7 @@ jobs:
shell: python
- name: Test
run: |
pip install "numpy<2" ml_dtypes
pip install "numpy" ml_dtypes
cargo test --all-features
# Not on PyPy, because no embedding API
if: ${{ !startsWith(matrix.python-version, 'pypy') }}
Expand All @@ -83,6 +82,52 @@ jobs:
CARGO_BUILD_TARGET: ${{ matrix.platform.rust-target }}
RUST_BACKTRACE: 1

test-numpy1:
name: python${{ matrix.python-version }}-${{ matrix.platform.python-architecture }} ${{ matrix.platform.os }} numpy1
runs-on: ${{ matrix.platform.os }}
needs: [lint, check-msrv, examples]
strategy:
fail-fast: ${{ !contains(github.event.pull_request.labels.*.name, 'CI-no-fail-fast') }}
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
platform: [
{ os: "ubuntu-latest", python-architecture: "x64", rust-target: "x86_64-unknown-linux-gnu" },
{ os: "macOS-13", python-architecture: "x64", rust-target: "x86_64-apple-darwin" },
{ os: "windows-latest", python-architecture: "x64", rust-target: "x86_64-pc-windows-msvc" },
]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
architecture: ${{ matrix.platform.python-architecture }}
- name: Install Rust
uses: dtolnay/rust-toolchain@stable
with:
targets: ${{ matrix.platform.rust-target }}
- name: Install toml
run: pip install toml
- name: Edit Cargo.toml and enable new resolver
run: |
import toml
cargo_toml = toml.load("Cargo.toml")
cargo_toml["package"]["resolver"] = "2"
with open("Cargo.toml", "w") as f:
toml.dump(cargo_toml, f)
Comment on lines +111 to +117
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A separate task, I think we can bump MSRV and remove this.

shell: python
- name: Test
run: |
pip install "numpy<2" ml_dtypes
cargo test --all-features
- name: Test example
run: |
pip install nox
nox -f examples/simple/noxfile.py
env:
CARGO_TERM_VERBOSE: true
RUST_BACKTRACE: 1

cross-build:
runs-on: ubuntu-latest
needs: [lint, check-msrv, examples]
Expand All @@ -104,7 +149,7 @@ jobs:
continue-on-error: true
- uses: taiki-e/install-action@valgrind
- run: |
pip install "numpy<2" ml_dtypes
pip install "numpy" ml_dtypes
cargo test --all-features --release
env:
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1
Expand All @@ -119,7 +164,7 @@ jobs:
continue-on-error: true
- uses: taiki-e/install-action@cargo-careful
- run: |
pip install "numpy<2" ml_dtypes
pip install "numpy" ml_dtypes
cargo careful test --all-features

check-msrv:
Expand Down Expand Up @@ -195,7 +240,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install numpy
run: pip install "numpy<2" ml_dtypes
run: pip install "numpy" ml_dtypes
- uses: Swatinem/rust-cache@v2
continue-on-error: true
- uses: dtolnay/rust-toolchain@stable
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Unreleased
- Add `permute` and `transpose` methods for changing the order of axes of a `PyArray`. ([#428](https://github.com/PyO3/rust-numpy/pull/428))
- Add support for NumPy v2 which had a number of changes to the [C API](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#c-api-changes). ([#442](https://github.com/PyO3/rust-numpy/pull/442))

- v0.21.0
- Migrate to the new `Bound` API introduced by PyO3 0.21. ([#410](https://github.com/PyO3/rust-numpy/pull/410)) ([#411](https://github.com/PyO3/rust-numpy/pull/411)) ([#412](https://github.com/PyO3/rust-numpy/pull/412)) ([#415](https://github.com/PyO3/rust-numpy/pull/415)) ([#416](https://github.com/PyO3/rust-numpy/pull/416)) ([#418](https://github.com/PyO3/rust-numpy/pull/418)) ([#419](https://github.com/PyO3/rust-numpy/pull/419)) ([#420](https://github.com/PyO3/rust-numpy/pull/420)) ([#421](https://github.com/PyO3/rust-numpy/pull/421)) ([#422](https://github.com/PyO3/rust-numpy/pull/422))
Expand Down
2 changes: 1 addition & 1 deletion examples/linalg/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

@nox.session
def tests(session):
session.install("pip", "numpy<2", "pytest")
session.install("pip", "numpy", "pytest")
session.run("pip", "install", ".", "-v")
session.run("pytest")
2 changes: 1 addition & 1 deletion examples/parallel/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

@nox.session
def tests(session):
session.install("pip", "numpy<2", "pytest")
session.install("pip", "numpy", "pytest")
session.run("pip", "install", ".", "-v")
session.run("pytest")
2 changes: 1 addition & 1 deletion examples/simple/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

@nox.session
def tests(session):
session.install("pip", "numpy<2", "pytest")
session.install("pip", "numpy", "pytest")
session.run("pip", "install", ".", "-v")
session.run("pytest")
54 changes: 27 additions & 27 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;
use crate::array::get_array_module;
use crate::cold;
use crate::error::BorrowError;
use crate::npyffi::{PyArrayObject, PyArray_Check, NPY_ARRAY_WRITEABLE};
use crate::npyffi::{PyArrayObject, PyArray_Check, PyDataType_ELSIZE, NPY_ARRAY_WRITEABLE};

/// Defines the shared C API used for borrow checking
///
Expand Down Expand Up @@ -48,7 +48,7 @@ unsafe extern "C" fn acquire_shared(flags: *mut c_void, array: *mut PyArrayObjec
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

match flags.acquire(address, key) {
Ok(()) => 0,
Expand All @@ -66,7 +66,7 @@ unsafe extern "C" fn acquire_mut_shared(flags: *mut c_void, array: *mut PyArrayO
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

match flags.acquire_mut(address, key) {
Ok(()) => 0,
Expand All @@ -80,7 +80,7 @@ unsafe extern "C" fn release_shared(flags: *mut c_void, array: *mut PyArrayObjec
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

flags.release(address, key);
}
Expand All @@ -91,7 +91,7 @@ unsafe extern "C" fn release_mut_shared(flags: *mut c_void, array: *mut PyArrayO
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

flags.release_mut(address, key);
}
Expand Down Expand Up @@ -379,8 +379,8 @@ fn base_address<'py>(py: Python<'py>, mut array: *mut PyArrayObject) -> *mut c_v
}
}

fn borrow_key(array: *mut PyArrayObject) -> BorrowKey {
let range = data_range(array);
fn borrow_key<'py>(py: Python<'py>, array: *mut PyArrayObject) -> BorrowKey {
let range = data_range(py, array);

let data_ptr = unsafe { (*array).data };
let gcd_strides = gcd_strides(array);
Expand All @@ -392,7 +392,7 @@ fn borrow_key(array: *mut PyArrayObject) -> BorrowKey {
}
}

fn data_range(array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
fn data_range<'py>(py: Python<'py>, array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
let nd = unsafe { (*array).nd } as usize;
let data = unsafe { (*array).data };

Expand All @@ -403,7 +403,7 @@ fn data_range(array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
let shape = unsafe { from_raw_parts((*array).dimensions as *mut usize, nd) };
let strides = unsafe { from_raw_parts((*array).strides, nd) };

let itemsize = unsafe { (*(*array).descr).elsize } as isize;
let itemsize = unsafe { PyDataType_ELSIZE(py, (*array).descr) } as isize;

let mut start = 0;
let mut end = 0;
Expand Down Expand Up @@ -468,7 +468,7 @@ mod tests {
let base_address = base_address(py, array.as_array_ptr());
assert_eq!(base_address, array.as_ptr().cast());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(6) } as *mut c_char);
});
Expand All @@ -486,7 +486,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(6).cast::<c_char>()
Expand Down Expand Up @@ -517,7 +517,7 @@ mod tests {
assert_ne!(base_address, view.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(4) } as *mut c_char);
});
Expand Down Expand Up @@ -550,7 +550,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(4).cast::<c_char>()
Expand Down Expand Up @@ -600,7 +600,7 @@ mod tests {
assert_ne!(base_address, view1.as_ptr().cast::<c_void>());
assert_eq!(base_address, base as *mut c_void);

let data_range = data_range(view2.as_array_ptr());
let data_range = data_range(py, view2.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(1) } as *mut c_char);
});
Expand Down Expand Up @@ -652,7 +652,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view2.as_array_ptr());
let data_range = data_range(py, view2.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(1).cast::<c_char>()
Expand Down Expand Up @@ -683,7 +683,7 @@ mod tests {
assert_ne!(base_address, view.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(view.data(), unsafe { array.data().offset(2) });
assert_eq!(data_range.0, unsafe { view.data().offset(-2) }
as *mut c_char);
Expand All @@ -703,7 +703,7 @@ mod tests {
let base_address = base_address(py, array.as_array_ptr());
assert_eq!(base_address, array.as_ptr().cast::<c_void>());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, array.data() as *mut c_char);
});
Expand All @@ -721,7 +721,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key1 = borrow_key(view1.as_array_ptr());
let key1 = borrow_key(py, view1.as_array_ptr());

assert_eq!(view1.strides(), &[80, 24]);
assert_eq!(key1.gcd_strides, 8);
Expand All @@ -732,7 +732,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key2 = borrow_key(view2.as_array_ptr());
let key2 = borrow_key(py, view2.as_array_ptr());

assert_eq!(view2.strides(), &[80, 24]);
assert_eq!(key2.gcd_strides, 8);
Expand All @@ -743,7 +743,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key3 = borrow_key(view3.as_array_ptr());
let key3 = borrow_key(py, view3.as_array_ptr());

assert_eq!(view3.strides(), &[80, 16]);
assert_eq!(key3.gcd_strides, 16);
Expand All @@ -754,7 +754,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key4 = borrow_key(view4.as_array_ptr());
let key4 = borrow_key(py, view4.as_array_ptr());

assert_eq!(view4.strides(), &[80, 16]);
assert_eq!(key4.gcd_strides, 16);
Expand All @@ -777,7 +777,7 @@ mod tests {
let base1 = base_address(py, array1.as_array_ptr());
let base2 = base_address(py, array2.as_array_ptr());

let key1 = borrow_key(array1.as_array_ptr());
let key1 = borrow_key(py, array1.as_array_ptr());
let _exclusive1 = array1.readwrite();

{
Expand All @@ -791,7 +791,7 @@ mod tests {
assert_eq!(flag, -1);
}

let key2 = borrow_key(array2.as_array_ptr());
let key2 = borrow_key(py, array2.as_array_ptr());
let _shared2 = array2.readonly();

{
Expand Down Expand Up @@ -827,7 +827,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key1 = borrow_key(view1.as_array_ptr());
let key1 = borrow_key(py, view1.as_array_ptr());
let exclusive1 = view1.readwrite();

{
Expand All @@ -847,7 +847,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key2 = borrow_key(view2.as_array_ptr());
let key2 = borrow_key(py, view2.as_array_ptr());
let shared2 = view2.readonly();

{
Expand All @@ -870,7 +870,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key3 = borrow_key(view3.as_array_ptr());
let key3 = borrow_key(py, view3.as_array_ptr());
let shared3 = view3.readonly();

{
Expand All @@ -896,7 +896,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key4 = borrow_key(view4.as_array_ptr());
let key4 = borrow_key(py, view4.as_array_ptr());
let shared4 = view4.readonly();

{
Expand Down
6 changes: 4 additions & 2 deletions src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ use pyo3::{sync::GILProtected, Bound, Py, Python};
use rustc_hash::FxHashMap;

use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods};
use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES};
use crate::npyffi::{
PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
};

/// Represents the [datetime units][datetime-units] supported by NumPy
///
Expand Down Expand Up @@ -230,7 +232,7 @@ impl TypeDescriptors {

// SAFETY: `self.npy_type` is either `NPY_DATETIME` or `NPY_TIMEDELTA` which implies the type of `c_metadata`.
unsafe {
let metadata = &mut *((*dtype.as_dtype_ptr()).c_metadata
let metadata = &mut *(PyDataType_C_METADATA(py, dtype.as_dtype_ptr())
as *mut PyArray_DatetimeDTypeMetaData);

metadata.meta.base = unit;
Expand Down
Loading