Skip to content

Commit

Permalink
Run tests using JAX as the backend array API (on CPU)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Feb 9, 2024
1 parent 9cf7037 commit 02987ba
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 6 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/jax-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Scheduled CI workflow: run the test suite with JAX (CPU) as the backend
# array API module instead of the default NumPy.
name: JAX tests

on:
  schedule:
    # Every weekday at 03:53 UTC, see https://crontab.guru/
    - cron: "53 3 * * 1-5"
  workflow_dispatch:

# Cancel any in-flight run for the same ref when a new one starts.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-latest"]
        python-version: ["3.9"]

    steps:
      - name: Checkout source
        uses: actions/checkout@v3
        with:
          # Full history (depth 0) so version detection from git tags works.
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64

      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v1

      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test]' 'jax[cpu]'
          python -m pip uninstall -y lithops # tests don't run on Lithops

      - name: Run tests
        run: |
          # exclude tests that rely on structured types since JAX doesn't support these
          pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick"
        env:
          CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
          JAX_ENABLE_X64: True
33 changes: 29 additions & 4 deletions cubed/backend_array_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
import os
from importlib import import_module

import numpy as np

# The array implementation used for backend operations.
# This must be compatible with the Python Array API standard, although
# The array implementation used for backend operations is stored in the
# namespace variable, and defaults to array_api_compat.numpy, unless it
# is overridden by an environment variable.
# It must be compatible with the Python Array API standard, although
# some extra functions are used too (nan functions, take_along_axis),
# which array_api_compat provides, but other Array API implementations
# may not.
import array_api_compat.numpy # noqa: F401 isort:skip

namespace = array_api_compat.numpy
if "CUBED_BACKEND_ARRAY_API_MODULE" in os.environ:
# This code is based on similar code in array_api_tests
xp_name = os.environ["CUBED_BACKEND_ARRAY_API_MODULE"]
_module, _sub = xp_name, None
if "." in xp_name:
_module, _sub = xp_name.split(".", 1)
xp = import_module(_module)
if _sub:
try:
xp = getattr(xp, _sub)
except AttributeError:
            # _sub may be a submodule that needs to be imported. We can't
# do this in every case because some array modules are not
# submodules that can be imported (like mxnet.nd).
xp = import_module(xp_name)
namespace = xp

else:
import array_api_compat.numpy

namespace = array_api_compat.numpy


# These functions to convert to/from backend arrays
# assume that no extra memory is allocated, by using the
Expand Down
5 changes: 3 additions & 2 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import zarr
from zarr.indexing import BasicIndexer, is_slice

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
Expand Down Expand Up @@ -107,7 +108,7 @@ class VirtualInMemoryArray:

def __init__(
self,
array: np.ndarray, # TODO: generalise
array: np.ndarray, # TODO: generalise to array API type
chunks: T_RegularChunks,
max_nbytes: int = 10**6,
):
Expand All @@ -129,7 +130,7 @@ def __init__(
self.chunks = template.chunks
self.template = template
if array.size > 0:
template[...] = array
template[...] = backend_array_to_numpy_array(array)

def __getitem__(self, key):
return self.array.__getitem__(key)
Expand Down

0 comments on commit 02987ba

Please sign in to comment.