Sharding Prototype II: implementation on Array #947

Closed · wants to merge 15 commits
55 changes: 55 additions & 0 deletions sharding_test.py
@@ -0,0 +1,55 @@
import json
import os

import zarr

store = zarr.DirectoryStore("data/chunking_test.zarr")
z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), store=store, overwrite=True, compressor=None)
z[:10, :] = 42
z[15, 1] = 389
z[19, 2] = 1
z[0, 1] = -4.2

print(store[".zarray"].decode())
# {
# "chunks": [
# 3,
# 2
# ],
# "compressor": null,
# "dtype": "<f8",
# "fill_value": 0.0,
# "filters": null,
# "order": "C",
# "shape": [
# 20,
# 3
# ],
# "shard_format": "indexed",
# "shards": [
# 2,
# 2
# ],
# "zarr_format": 2
# }

assert json.loads(store[".zarray"].decode())["shards"] == [2, 2]

print("ONDISK", sorted(os.listdir("data/chunking_test.zarr")))
print("STORE", sorted(store))
print("CHUNKSTORE (SHARDED)", sorted(z.chunk_store))

# ONDISK ['.zarray', '0.0', '1.0', '2.0', '3.0']
# STORE ['.zarray', '0.0', '1.0', '2.0', '3.0']
# CHUNKSTORE (SHARDED) ['.zarray', '0.0', '1.0', '2.0', '3.0']

index_bytes = z.store["0.0"][-2*2*16:]
print("INDEX 0.0", [int.from_bytes(index_bytes[i:i+8], byteorder="little") for i in range(0, len(index_bytes), 8)])
# INDEX 0.0 [0, 48, 48, 48, 96, 48, 144, 48]

z_reopened = zarr.open("data/chunking_test.zarr")
assert z_reopened.shards == (2, 2)
assert z_reopened[15, 1] == 389
assert z_reopened[19, 2] == 1
assert z_reopened[0, 1] == -4.2
assert z_reopened[0, 0] == 42
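
The trailing 64 bytes read back above decode into one little-endian (offset, length) uint64 pair per chunk. A minimal sketch of that decoding with NumPy, continuing the script (the pairs variable is illustrative, not part of the PR):

import numpy as np

# shards=(2, 2) gives 4 chunks per shard, so the index is 4 * 16 = 64 bytes.
pairs = np.frombuffer(index_bytes, dtype="<u8").reshape(-1, 2)
# Each chunk holds 3 * 2 float64 values = 48 bytes (compressor=None), hence
# offsets 0, 48, 96 and 144 with length 48 each, matching INDEX 0.0 above.
print(pairs.tolist())  # [[0, 48], [48, 48], [96, 48], [144, 48]]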
2 changes: 2 additions & 0 deletions zarr/_storage/store.py
@@ -110,6 +110,8 @@ def _ensure_store(store: Any):


class Store(BaseStore):
# TODO: document methods which allow optimizations,
# e.g. delitems, setitems, getitems, listdir, …
"""Abstract store class used by implementations following the Zarr v2 spec.

Adds public `listdir`, `rename`, and `rmdir` methods on top of BaseStore.
186 changes: 172 additions & 14 deletions zarr/core.py
@@ -1,10 +1,12 @@
import binascii
from collections import defaultdict
import hashlib
import itertools
import math
import operator
import re
from functools import reduce
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, NamedTuple, Union

import numpy as np
from numcodecs.compat import ensure_bytes, ensure_ndarray
@@ -47,6 +49,55 @@
)


MAX_UINT_64 = 2 ** 64 - 1

class _ShardIndex(NamedTuple):
array: "Array"
offsets_and_lengths: np.ndarray  # dtype uint64, shape (shards_0, shards_1, ..., 2)

def __localize_chunk__(self, chunk: Tuple[int, ...]) -> Tuple[int, ...]:
return tuple(chunk_i % shard_i for chunk_i, shard_i in zip(chunk, self.array._shards))

def get_chunk_slice(self, chunk: Tuple[int, ...]) -> Optional[slice]:
localized_chunk = self.__localize_chunk__(chunk)
chunk_start, chunk_len = self.offsets_and_lengths[localized_chunk]
if (chunk_start, chunk_len) == (MAX_UINT_64, MAX_UINT_64):
return None
else:
return slice(chunk_start, chunk_start + chunk_len)

def set_chunk_slice(self, chunk: Tuple[int, ...], chunk_slice: Optional[slice]) -> None:
localized_chunk = self.__localize_chunk__(chunk)
if chunk_slice is None:
self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
else:
self.offsets_and_lengths[localized_chunk] = (
chunk_slice.start,
chunk_slice.stop - chunk_slice.start
)

def to_bytes(self) -> bytes:
return self.offsets_and_lengths.tobytes(order='C')

@classmethod
def from_bytes(
cls, buffer: Union[bytes, bytearray], array: "Array"
) -> "_ShardIndex":
return cls(
array=array,
offsets_and_lengths=np.frombuffer(
bytearray(buffer), dtype="<u8"
).reshape(*array._shards, 2, order="C")
)

@classmethod
def create_empty(cls, array: "Array"):
# reserving 2*64bit per chunk for offset and length:
return cls.from_bytes(
MAX_UINT_64.to_bytes(8, byteorder="little") * (2 * array._num_chunks_per_shard),
array=array
)
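    # Illustration (annotation, not part of the diff): an empty index for
    # shards=(2, 2) is four (offset, length) pairs of MAX_UINT_64, i.e.
    # 64 bytes of 0xFF; get_chunk_slice reports every chunk as missing
    # until set_chunk_slice records real offsets.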

# noinspection PyUnresolvedReferences
class Array:
"""Instantiate an array from an initialized store.
@@ -191,6 +242,9 @@ def __init__(
self._oindex = OIndex(self)
self._vindex = VIndex(self)

# the sharded store is only initialized when needed
self._cached_sharded_store = None

def _load_metadata(self):
"""(Re)load metadata from store."""
if self._synchronizer is None:
@@ -213,6 +267,8 @@ def _load_metadata_nosync(self):
self._meta = meta
self._shape = meta['shape']
self._chunks = meta['chunks']
self._shards = meta.get('shards')
self._shard_format = meta.get('shard_format')
self._dtype = meta['dtype']
self._fill_value = meta['fill_value']
self._order = meta['order']
@@ -262,9 +318,12 @@ def _flush_metadata_nosync(self):
filters_config = [f.get_config() for f in self._filters]
else:
filters_config = None
# Possible (unrelated) bug:
# should the dimension_separator also be included in this dict?
meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
compressor=compressor_config, fill_value=self._fill_value,
order=self._order, filters=filters_config)
order=self._order, filters=filters_config,
shards=self._shards, shard_format=self._shard_format)
mkey = self._key_prefix + array_meta_key
self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)

@@ -327,11 +386,17 @@ def shape(self, value):
self.resize(value)

@property
def chunks(self):
def chunks(self) -> Optional[Tuple[int, ...]]:
"""A tuple of integers describing the length of each dimension of a
chunk of the array."""
chunk of the array, or None."""
return self._chunks

@property
def shards(self):
"""A tuple of integers describing the number of chunks in each shard
of the array."""
return self._shards

@property
def dtype(self):
"""The NumPy data type."""
@@ -487,6 +552,38 @@ def write_empty_chunks(self) -> bool:
"""
return self._write_empty_chunks

@property
def _num_chunks_per_shard(self) -> int:
return reduce(lambda x, y: x*y, self._shards, 1)

def __keys_to_shard_groups__(
self, keys: Iterable[str]
) -> Dict[str, List[Tuple[str, Tuple[int, ...]]]]:
shard_indices_per_shard_key = defaultdict(list)
for chunk_key in keys:
# TODO: allow the array to be in a group (i.e. only use the last key parts for dimensions)
chunk_subkeys = tuple(map(int, chunk_key.split(self._dimension_separator)))
shard_key_tuple = (
subkey // shard_i for subkey, shard_i in zip(chunk_subkeys, self._shards)
)
shard_key = self._dimension_separator.join(map(str, shard_key_tuple))
shard_indices_per_shard_key[shard_key].append((chunk_key, chunk_subkeys))
return shard_indices_per_shard_key
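    # Example (annotation, not part of the diff): with shards=(2, 2) and "."
    # as dimension separator, the chunk keys "2.0" and "3.1" both fall into
    # shard "1.0", grouped as {"1.0": [("2.0", (2, 0)), ("3.1", (3, 1))]}.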

def __get_index__(self, buffer: Union[bytes, bytearray]) -> _ShardIndex:
# The index at the end of each shard stores 2 * 64 bit (offset and length) per chunk:
return _ShardIndex.from_bytes(buffer[-16 * self._num_chunks_per_shard:], self)

def __get_chunks_in_shard(self, shard_key: str) -> Iterator[Tuple[int, ...]]:
# TODO: allow the array to be in a group (i.e. only use the last key parts for dimensions)
shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
for chunk_offset in itertools.product(*(range(i) for i in self._shards)):
yield tuple(
shard_key_i * shards_i + offset_i
for shard_key_i, offset_i, shards_i
in zip(shard_key_tuple, chunk_offset, self._shards)
)
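    # Example (annotation, not part of the diff): for shard key "1.0" with
    # shards=(2, 2) this yields the chunks (2, 0), (2, 1), (3, 0) and (3, 1).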

def __eq__(self, other):
return (
isinstance(other, Array) and
@@ -857,7 +954,7 @@ def _get_basic_selection_zd(self, selection, out=None, fields=None):
try:
# obtain encoded data for chunk
ckey = self._chunk_key((0,))
cdata = self.chunk_store[ckey]
cdata = self._read_single_possibly_sharded(ckey)

except KeyError:
# chunk not initialized
@@ -1170,8 +1267,10 @@ def _get_selection(self, indexer, out=None, fields=None):
check_array_shape('out', out, out_shape)

# iterate over chunks
if not hasattr(self.chunk_store, "getitems") or \
any(map(lambda x: x == 0, self.shape)):
if self._shards is None and (
not hasattr(self.chunk_store, "getitems")
or any(map(lambda x: x == 0, self.shape))
):
# sequentially get one key at a time from storage
for chunk_coords, chunk_selection, out_selection in indexer:

@@ -1640,7 +1739,7 @@ def _set_basic_selection_zd(self, selection, value, fields=None):
# setup chunk
try:
# obtain compressed data for chunk
cdata = self.chunk_store[ckey]
cdata = self._read_single_possibly_sharded(ckey)

except KeyError:
# chunk not initialized
@@ -1708,8 +1807,11 @@ def _set_selection(self, indexer, value, fields=None):
check_array_shape('value', value, sel_shape)

# iterate over chunks in range
if not hasattr(self.store, "setitems") or self._synchronizer is not None \
or any(map(lambda x: x == 0, self.shape)):
if self._shards is None and (
not hasattr(self.chunk_store, "setitems")
or self._synchronizer is not None
or any(map(lambda x: x == 0, self.shape))
):
# iterative approach
for chunk_coords, chunk_selection, out_selection in indexer:

@@ -1883,6 +1985,21 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection,
self._process_chunk(out, cdata, chunk_selection, drop_axes,
out_is_ndarray, fields, out_selection)

def _read_single_possibly_sharded(self, ckey):
if self._shards is None:
return self.chunk_store[ckey]
else:
shard_key, chunks_in_shard = next(iter(self.__keys_to_shard_groups__([ckey]).items()))
# TODO use partial read if available
full_shard_value = self.chunk_store[shard_key]
index = self.__get_index__(full_shard_value)
for _chunk_key, chunk_subkeys in chunks_in_shard:
chunk_slice = index.get_chunk_slice(chunk_subkeys)
if chunk_slice is None:
raise KeyError(ckey)
else:
return full_shard_value[chunk_slice]
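    # Note (annotation, not part of the diff): exactly one chunk key is passed
    # in, so chunks_in_shard has a single entry; the loop either returns that
    # chunk's bytes or raises KeyError, mimicking an uninitialized chunk.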

def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
drop_axes=None, fields=None):
"""As _chunk_getitem, but for lists of chunks
@@ -1904,6 +2021,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
and hasattr(self._compressor, "decode_partial")
and not fields
and self.dtype != object
# TODO: this should rather check for read_block or similar
and hasattr(self.chunk_store, "getitems")
):
partial_read_decode = True
@@ -1914,7 +2032,16 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
}
else:
partial_read_decode = False
cdatas = self.chunk_store.getitems(ckeys, on_error="omit")
cdatas = {}
for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(ckeys).items():
# TODO use partial read if available
full_shard_value = self.chunk_store[shard_key]
index = self.__get_index__(full_shard_value)
for chunk_key, chunk_subkeys in chunks_in_shard:
chunk_slice = index.get_chunk_slice(chunk_subkeys)
if chunk_slice is not None:
cdatas[chunk_key] = full_shard_value[chunk_slice]
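        # Note (annotation, not part of the diff): chunks absent from a shard
        # index are left out of cdatas, mirroring the previous
        # getitems(..., on_error="omit") behaviour; absent chunks then fall
        # back to the fill value below.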

for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
if ckey in cdatas:
self._process_chunk(
@@ -1948,11 +2075,40 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
to_store = {k: self._encode_chunk(cdatas[k]) for k in nonempty_keys}
else:
to_store = {k: self._encode_chunk(v) for k, v in cdatas.items()}
self.chunk_store.setitems(to_store)

for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(to_store.keys()).items():
all_chunks = set(self.__get_chunks_in_shard(shard_key))
chunks_to_set = set(chunk_subkeys for _chunk_key, chunk_subkeys in chunks_in_shard)
chunks_to_read = all_chunks - chunks_to_set
new_content = {
chunk_subkeys: to_store[chunk_key] for chunk_key, chunk_subkeys in chunks_in_shard
}
try:
# TODO use partial read if available
full_shard_value = self.chunk_store[shard_key]
except KeyError:
index = _ShardIndex.create_empty(self)
else:
index = self.__get_index__(full_shard_value)
for chunk_to_read in chunks_to_read:
chunk_slice = index.get_chunk_slice(chunk_to_read)
if chunk_slice is not None:
new_content[chunk_to_read] = full_shard_value[chunk_slice]

# TODO use partial write if available and possible (e.g. at the end)
shard_content = b""
# TODO: order the chunks in the shard:
for chunk_subkeys, chunk_content in new_content.items():
chunk_slice = slice(len(shard_content), len(shard_content) + len(chunk_content))
index.set_chunk_slice(chunk_subkeys, chunk_slice)
shard_content += chunk_content
# Appending the index at the end of the shard:
shard_content += index.to_bytes()
self.chunk_store[shard_key] = shard_content
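        # Worked example (annotation, not part of the diff): for the test
        # array above, a full shard is four 48-byte chunks plus the 64-byte
        # index, i.e. 256 bytes; updating one chunk of an existing shard
        # re-reads the other three and rewrites the whole shard key.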

def _chunk_delitems(self, ckeys):
if hasattr(self.store, "delitems"):
self.store.delitems(ckeys)
if hasattr(self.chunk_store, "delitems"):
self.chunk_store.delitems(ckeys)
else: # pragma: no cover
# exempting this branch from coverage as there are no extant stores
# that will trigger this condition, but it's possible that they
@@ -2028,7 +2184,7 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):
try:

# obtain compressed data for chunk
cdata = self.chunk_store[ckey]
cdata = self._read_single_possibly_sharded(ckey)

except KeyError:

@@ -2239,6 +2395,7 @@ def digest(self, hashname="sha1"):

h = hashlib.new(hashname)

# TODO: operate on shards here if available:
for i in itertools.product(*[range(s) for s in self.cdata_shape]):
h.update(self.chunk_store.get(self._chunk_key(i), b""))

@@ -2365,6 +2522,7 @@ def _resize_nosync(self, *args):
except KeyError:
# chunk not initialized
pass
# TODO: collect all chunks to delete and use _chunk_delitems

def append(self, data, axis=0):
"""Append `data` to `axis`.