diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index 9368f8d4..b7c1c610 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -22,6 +22,7 @@
 from cubed.core.array import CoreArray, check_array_specs, compute, gensym
 from cubed.core.plan import Plan, new_temp_path
 from cubed.primitive.blockwise import blockwise as primitive_blockwise
+from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
 from cubed.primitive.rechunk import rechunk as primitive_rechunk
 from cubed.utils import chunk_memory, get_item, to_chunksize
 from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
@@ -295,6 +296,60 @@ def blockwise(
     return Array(name, pipeline.target_array, spec, plan)
 
 
+def general_blockwise(
+    func,
+    block_function,
+    *arrays,
+    shape,
+    dtype,
+    chunks,
+    target_store=None,
+    extra_func_kwargs=None,
+    **kwargs,
+) -> "Array":
+    assert len(arrays) > 0
+
+    # replace arrays with zarr arrays
+    zargs = [a.zarray_maybe_lazy for a in arrays]
+    in_names = [a.name for a in arrays]
+
+    extra_source_arrays = kwargs.pop("extra_source_arrays", [])
+    source_arrays = list(arrays) + list(extra_source_arrays)
+
+    extra_projected_mem = kwargs.pop("extra_projected_mem", 0)
+
+    name = gensym()
+    spec = check_array_specs(arrays)
+    if target_store is None:
+        target_store = new_temp_path(name=name, spec=spec)
+    pipeline = primitive_general_blockwise(
+        func,
+        block_function,
+        *zargs,
+        allowed_mem=spec.allowed_mem,
+        reserved_mem=spec.reserved_mem,
+        extra_projected_mem=extra_projected_mem,
+        target_store=target_store,
+        shape=shape,
+        dtype=dtype,
+        chunks=chunks,
+        in_names=in_names,
+        extra_func_kwargs=extra_func_kwargs,
+        **kwargs,
+    )
+    plan = Plan._new(
+        name,
+        "blockwise",
+        pipeline.target_array,
+        pipeline,
+        False,
+        *source_arrays,
+    )
+    from cubed.array_api import Array
+
+    return Array(name, pipeline.target_array, spec, plan)
+
+
 def elemwise(func, *args: "Array", dtype=None) -> "Array":
     """Apply a function elementwise to array arguments, respecting broadcasting."""
     shapes = [arg.shape for arg in args]
diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py
index 7bbfa8f1..8e23c9c2 100644
--- a/cubed/primitive/blockwise.py
+++ b/cubed/primitive/blockwise.py
@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import toolz
 import zarr
@@ -15,7 +15,7 @@
 from cubed.runtime.types import CubedPipeline
 from cubed.storage.zarr import T_ZarrArray, lazy_empty
 from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
-from cubed.utils import chunk_memory, get_item, split_into, to_chunksize
+from cubed.utils import chunk_memory, get_item, map_nested, split_into, to_chunksize
 from cubed.vendor.dask.array.core import normalize_chunks
 from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
 from cubed.vendor.dask.core import flatten
@@ -40,7 +40,7 @@ class BlockwiseSpec:
     Attributes
     ----------
     block_function : Callable
-        A function that maps input chunk indexes to an output chunk index.
+        A function that maps an output chunk index to one or more input chunk indexes.
     function : Callable
         A function that maps input chunks to an output chunk.
     reads_map : Dict[str, CubedArrayProxy]
@@ -62,15 +62,13 @@ def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
     out_chunk_key = key_to_slices(
         out_key_tuple, config.write.array, config.write.chunks
     )
+
+    # get array chunks for input keys, preserving any nested list structure
     args = []
+    get_chunk_config = partial(get_chunk, config=config)
     name_chunk_inds = config.block_function(("out",) + out_key_tuple)
     for name_chunk_ind in name_chunk_inds:
-        name = name_chunk_ind[0]
-        chunk_ind = name_chunk_ind[1:]
-        arr = config.reads_map[name].open()
-        chunk_key = key_to_slices(chunk_ind, arr)
-        arg = arr[chunk_key]
-        arg = numpy_array_to_backend_array(arg)
+        arg = map_nested(get_chunk_config, name_chunk_ind)
         args.append(arg)
 
     result = config.function(*args)
@@ -91,6 +89,17 @@ def key_to_slices(
     return get_item(chunks, key)
 
 
+def get_chunk(name_chunk_ind, config):
+    """Read a chunk from the named array"""
+    name = name_chunk_ind[0]
+    chunk_ind = name_chunk_ind[1:]
+    arr = config.reads_map[name].open()
+    chunk_key = key_to_slices(chunk_ind, arr)
+    arg = arr[chunk_key]
+    arg = numpy_array_to_backend_array(arg)
+    return arg
+
+
 def blockwise(
     func: Callable[..., Any],
     out_ind: Sequence[Union[str, int]],
@@ -110,6 +119,9 @@ def blockwise(
 ):
     """Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
 
+    Unlike ``general_blockwise``, an index notation is used to specify the block mapping,
+    like in Dask Array.
+
     Parameters
     ----------
     func : callable
@@ -146,10 +158,8 @@ def blockwise(
     CubedPipeline to run the operation
     """
 
-    # Use dask's make_blockwise_graph
     arrays: Sequence[T_ZarrArray] = args[::2]
     array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
-    array_map = {name: array for name, array in zip(array_names, arrays)}
 
     inds: Sequence[Union[str, int]] = args[1::2]
 
@@ -164,11 +174,6 @@ def blockwise(
     for name, ind in zip(array_names, inds):
         argindsstr.extend((name, ind))
 
-    # TODO: check output shape and chunks are consistent with inputs
-    chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
-
-    # block func
-
     block_function = make_blockwise_function_flattened(
         func,
         out_name or "out",
@@ -178,28 +183,78 @@ def blockwise(
         new_axes=new_axes,
     )
 
-    output_blocks_generator_fn = partial(
-        get_output_blocks,
+    return general_blockwise(
         func,
-        out_name or "out",
-        out_ind,
-        *argindsstr,
-        numblocks=numblocks,
-        new_axes=new_axes,
+        block_function,
+        *arrays,
+        allowed_mem=allowed_mem,
+        reserved_mem=reserved_mem,
+        target_store=target_store,
+        shape=shape,
+        dtype=dtype,
+        chunks=chunks,
+        in_names=in_names,
+        extra_projected_mem=extra_projected_mem,
+        extra_func_kwargs=extra_func_kwargs,
+        **kwargs,
     )
-    output_blocks = IterableFromGenerator(output_blocks_generator_fn)
-
-    num_tasks = num_output_blocks(
-        func,
-        out_name or "out",
-        out_ind,
-        *argindsstr,
-        numblocks=numblocks,
-        new_axes=new_axes,
-    )
-
-    # end block func
-
+
+
+def general_blockwise(
+    func: Callable[..., Any],
+    block_function: Callable[..., Any],
+    *arrays: Any,
+    allowed_mem: int,
+    reserved_mem: int,
+    target_store: T_Store,
+    shape: T_Shape,
+    dtype: T_DType,
+    chunks: T_Chunks,
+    in_names: Optional[List[str]] = None,
+    extra_projected_mem: int = 0,
+    extra_func_kwargs: Optional[Dict[str, Any]] = None,
+    **kwargs,
+):
+    """A more general form of ``blockwise`` that uses a function to specify the block
+    mapping, rather than an index notation.
+
+    Parameters
+    ----------
+    func : callable
+        Function to apply to individual tuples of blocks
+    block_function : callable
+        A function that maps an output chunk index to one or more input chunk indexes.
+    *arrays : sequence of Array
+        The input arrays.
+    allowed_mem : int
+        The memory available to a worker for running a task, in bytes. Includes ``reserved_mem``.
+    reserved_mem : int
+        The memory reserved on a worker for non-data use when running a task, in bytes
+    target_store : string or zarr.Array
+        Path to output Zarr store, or Zarr array
+    shape : tuple
+        The shape of the output array.
+    dtype : np.dtype
+        The ``dtype`` of the output array.
+    chunks : tuple
+        The chunks of the output array.
+    extra_projected_mem : int
+        Extra memory projected to be needed (in bytes) in addition to the memory used reading
+        the input arrays and writing the output.
+    extra_func_kwargs : dict
+        Extra keyword arguments to pass to function that can't be passed as regular keyword arguments
+        since they clash with other blockwise arguments (such as dtype).
+    **kwargs : dict
+        Extra keyword arguments to pass to function
+
+    Returns
+    -------
+    CubedPipeline to run the operation
+    """
+    array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
+    array_map = {name: array for name, array in zip(array_names, arrays)}
+
+    chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
     chunksize = to_chunksize(chunks)
     if isinstance(target_store, zarr.Array):
         target_array = target_store
@@ -236,6 +291,10 @@ def blockwise(
             f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})"
         )
 
+    # this must be an iterator of lists, not of tuples, otherwise lithops breaks
+    output_blocks = map(list, itertools.product(*[range(len(c)) for c in chunks]))
+    num_tasks = math.prod(len(c) for c in chunks)
+
     return CubedPipeline(
         apply_blockwise,
         gensym("apply_blockwise"),
@@ -474,50 +533,3 @@ def blockwise_fn_flattened(out_key):
         return name_chunk_inds
 
     return blockwise_fn_flattened
-
-
-def get_output_blocks(
-    func: Callable[..., Any],
-    output: str,
-    out_indices: Sequence[Union[str, int]],
-    *arrind_pairs: Any,
-    numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
-    new_axes: Optional[Dict[int, int]] = None,
-) -> Iterator[List[int]]:
-    if numblocks is None:
-        raise ValueError("Missing required numblocks argument.")
-    new_axes = new_axes or {}
-    argpairs = list(toolz.partition(2, arrind_pairs))
-
-    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
-    dims = _make_dims(argpairs, numblocks, new_axes)
-
-    # return a list of lists, not of tuples, otherwise lithops breaks
-    for tup in itertools.product(*[range(dims[i]) for i in out_indices]):
-        yield list(tup)
-
-
-class IterableFromGenerator:
-    def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]):
-        self.generator_fn = generator_fn
-
-    def __iter__(self):
-        return self.generator_fn()
-
-
-def num_output_blocks(
-    func: Callable[..., Any],
-    output: str,
-    out_indices: Sequence[Union[str, int]],
-    *arrind_pairs: Any,
-    numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
-    new_axes: Optional[Dict[int, int]] = None,
-) -> int:
-    if numblocks is None:
-        raise ValueError("Missing required numblocks argument.")
-    new_axes = new_axes or {}
-    argpairs = list(toolz.partition(2, arrind_pairs))
-
-    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
-    dims = _make_dims(argpairs, numblocks, new_axes)
-    return math.prod(dims[i] for i in out_indices)
diff --git a/cubed/tests/primitive/test_blockwise.py b/cubed/tests/primitive/test_blockwise.py
index d1af9caf..0427e5ca 100644
--- a/cubed/tests/primitive/test_blockwise.py
+++ b/cubed/tests/primitive/test_blockwise.py
@@ -4,7 +4,11 @@
 from numpy.testing import assert_array_equal
 
 from cubed.backend_array_api import namespace as nxp
-from cubed.primitive.blockwise import blockwise, make_blockwise_function
+from cubed.primitive.blockwise import (
+    blockwise,
+    general_blockwise,
+    make_blockwise_function,
+)
 from cubed.runtime.executors.python import PythonDagExecutor
 from cubed.tests.utils import create_zarr, execute_pipeline
 from cubed.vendor.dask.blockwise import make_blockwise_graph
@@ -162,6 +166,64 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem):
     )
 
 
+def test_general_blockwise(tmp_path, executor):
+    source = create_zarr(
+        list(range(20)),
+        dtype=int,
+        chunks=(2,),
+        store=tmp_path / "source.zarr",
+    )
+    allowed_mem = 1000
+    target_store = tmp_path / "target.zarr"
+
+    numblocks = 10
+    in_name = "x"
+    merge_factor = 3
+
+    def merge_chunks(xs):
+        return nxp.concat(xs, axis=0)
+
+    def block_function(out_key):
+        out_coords = out_key[1:]
+
+        k = merge_factor
+        out_coord = out_coords[0]  # this is just 1d
+        # return a tuple with a single item that is the list of input keys to be merged
+        return (
+            [
+                (in_name, out_coord * k + i)
+                for i in range(k)
+                if out_coord * k + i < numblocks
+            ],
+        )
+
+    pipeline = general_blockwise(
+        merge_chunks,
+        block_function,
+        source,
+        allowed_mem=allowed_mem,
+        reserved_mem=0,
+        target_store=target_store,
+        shape=(20,),
+        dtype=int,
+        chunks=(6,),
+        in_names=[in_name],
+    )
+
+    assert pipeline.target_array.shape == (20,)
+    assert pipeline.target_array.dtype == int
+    assert pipeline.target_array.chunks == (6,)
+
+    assert pipeline.num_tasks == 4
+
+    pipeline.target_array.create()  # create lazy zarr array
+
+    execute_pipeline(pipeline, executor=executor)
+
+    res = zarr.open_array(target_store)
+    assert_array_equal(res[:], np.arange(20))
+
+
 def test_make_blockwise_function_map():
     func = lambda x: 0
diff --git a/cubed/tests/test_utils.py b/cubed/tests/test_utils.py
index b9a46204..4c1b0d6d 100644
--- a/cubed/tests/test_utils.py
+++ b/cubed/tests/test_utils.py
@@ -8,6 +8,7 @@
     chunk_memory,
     extract_stack_summaries,
     join_path,
+    map_nested,
     memory_repr,
     peak_measured_mem,
     split_into,
@@ -77,3 +78,62 @@ def test_split_into():
     assert list(split_into([1, 2, 3, 4, 5, 6], [1, 2, 3])) == [[1], [2, 3], [4, 5, 6]]
     assert list(split_into([1, 2, 3, 4, 5, 6], [2, 3])) == [[1, 2], [3, 4, 5]]
     assert list(split_into([1, 2, 3, 4], [1, 2, 3, 4])) == [[1], [2, 3], [4], []]
+
+
+def test_map_nested_lists():
+    inc = lambda x: x + 1
+
+    assert map_nested(inc, [1, 2]) == [2, 3]
+    assert map_nested(inc, [[1, 2]]) == [[2, 3]]
+    assert map_nested(inc, [[1, 2], [3, 4]]) == [[2, 3], [4, 5]]
+
+
+count = 0
+
+
+def inc(x):
+    global count
+    count = count + 1
+    return x + 1
+
+
+def test_map_nested_iterators():
+    # same tests as test_map_nested_lists, but use a counter to check that iterators are advanced at correct points
+    global count
+
+    out = map_nested(inc, iter([1, 2]))
+    assert isinstance(out, map)
+    assert count == 0
+    assert list(out) == [2, 3]
+    assert count == 2
+
+    # reset count
+    count = 0
+
+    out = map_nested(inc, [iter([1, 2])])
+    assert isinstance(out, list)
+    assert count == 0
+    assert len(out) == 1
+    out = out[0]
+    assert isinstance(out, map)
+    assert count == 0
+    assert list(out) == [2, 3]
+    assert count == 2
+
+    # reset count
+    count = 0
+
+    out = map_nested(inc, [iter([1, 2]), iter([3, 4])])
+    assert isinstance(out, list)
+    assert count == 0
+    assert len(out) == 2
+    out0 = out[0]
+    assert isinstance(out0, map)
+    assert count == 0
+    assert list(out0) == [2, 3]
+    assert count == 2
+    out1 = out[1]
+    assert isinstance(out1, map)
+    assert count == 2
+    assert list(out1) == [4, 5]
+    assert count == 4
diff --git a/cubed/utils.py b/cubed/utils.py
index 9cc66b89..819247bc 100644
--- a/cubed/utils.py
+++ b/cubed/utils.py
@@ -4,6 +4,7 @@
 import sys
 import sysconfig
 import traceback
+from collections.abc import Iterator
 from dataclasses import dataclass
 from itertools import islice
 from math import prod
@@ -252,3 +253,29 @@ def split_into(iterable, sizes):
     it = iter(iterable)
     for size in sizes:
         yield list(islice(it, size))
+
+
+def map_nested(func, seq):
+    """Apply a function inside nested lists or iterators, while preserving
+    the nesting, and the collection or iterator type.
+
+    Examples
+    --------
+
+    >>> from cubed.utils import map_nested
+    >>> inc = lambda x: x + 1
+    >>> map_nested(inc, [[1, 2], [3, 4]])
+    [[2, 3], [4, 5]]
+
+    >>> it = map_nested(inc, iter([1, 2]))
+    >>> next(it)
+    2
+    >>> next(it)
+    3
+    """
+    if isinstance(seq, list):
+        return [map_nested(func, item) for item in seq]
+    elif isinstance(seq, Iterator):
+        return map(lambda item: map_nested(func, item), seq)
+    else:
+        return func(seq)
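
# --- Illustrative usage sketch (not part of the patch above) -----------------
# A minimal sketch of how the new array-level general_blockwise in
# cubed/core/ops.py could be used to merge neighbouring chunks, mirroring
# test_general_blockwise. The array `x`, its chunking, and merge_factor are
# assumptions chosen for this example; note that block_function keys input
# chunks by the source array's generated name (x.name), since the wrapper
# passes in_names=[a.name for a in arrays] to the primitive.
import cubed
import cubed.array_api as xp
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import general_blockwise

spec = cubed.Spec(allowed_mem=200_000_000)
x = xp.arange(20, chunks=(2,), spec=spec)  # 10 input chunks of size 2

merge_factor = 3  # each output chunk is built from up to 3 input chunks
numblocks = 10  # 20 elements / chunk size 2


def merge_chunks(xs):
    # xs is the list of input chunks selected by block_function for one output chunk
    return nxp.concat(xs, axis=0)


def block_function(out_key):
    # out_key is ("out", i); output chunk i reads input chunks 3*i, 3*i+1, 3*i+2
    out_coord = out_key[1]
    k = merge_factor
    return (
        [
            (x.name, out_coord * k + i)
            for i in range(k)
            if out_coord * k + i < numblocks
        ],
    )


y = general_blockwise(
    merge_chunks, block_function, x, shape=(20,), dtype=x.dtype, chunks=(6,)
)
# y.compute() then yields the same values as x, rechunked to chunks of (6, 6, 6, 2).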