Organize, document, and simplify transforms.py #726
base: main
@@ -63,12 +63,13 @@ class MergeDim(CombineDim): | |
operation: ClassVar[CombineOp] = CombineOp.MERGE | ||
|
||
|
||
def augment_index_with_start_stop( | ||
def augment_index_with_byte_range( | ||
position: Position, | ||
item_lens: List[int], | ||
append_offset: int = 0, | ||
) -> IndexedPosition: | ||
"""Take an index _without_ start / stop and add them based on the lens defined in sequence_lens. | ||
"""Take an index _without_ start / stop (byte range) and add them based on the lens defined in | ||
Review comment: Same comment as above re: bytes. |
||
sequence_lens. | ||
|
||
:param index: The ``DimIndex`` instance to augment. | ||
:param item_lens: A list of integer lengths for all items in the sequence. | ||
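Reading the docstring together with the review thread at the bottom of this diff, my understanding is that the position is a logical item position and the computed range comes from the cumulative item lengths. A standalone sketch of that arithmetic (an assumption on my part, not copied from the library; the real function returns an IndexedPosition rather than a plain tuple):

    # Hypothetical sketch: derive a (start, stop) range for the item at
    # `position_value` from the lengths of all items in the sequence.
    def sketch_augment(position_value, item_lens, append_offset=0):
        start = sum(item_lens[:position_value]) + append_offset  # everything before this item
        stop = start + item_lens[position_value]                 # plus this item's own length
        return start, stop

    print(sketch_augment(2, [10, 10, 12]))  # (20, 32)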
@@ -20,10 +20,10 @@ | |
import zarr | ||
from kerchunk.combine import MultiZarrToZarr | ||
|
||
from .aggregation import XarraySchema, dataset_to_schema, schema_to_template_ds, schema_to_zarr | ||
from .aggregation import dataset_to_schema, schema_to_template_ds, schema_to_zarr | ||
from .combiners import CombineXarraySchemas, MinMaxCountCombineFn | ||
from .openers import open_url, open_with_kerchunk, open_with_xarray | ||
from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_start_stop | ||
from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_byte_range | ||
from .rechunking import combine_fragments, consolidate_dimension_coordinates, split_fragment | ||
from .storage import CacheFSSpecTarget, FSSpecTarget | ||
from .types import Indexed | ||
|
@@ -241,33 +241,20 @@ def expand(self, pcoll): | |
) | ||
|
||
|
||
@dataclass | ||
class DatasetToSchema(beam.PTransform): | ||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
return pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v))) | ||
|
||
|
||
def _nest_dim(item: Indexed[T], dimension: Dimension) -> Indexed[Indexed[T]]: | ||
"""Nest dimensions to support multiple combine dimensions""" | ||
index, value = item | ||
inner_index = Index({dimension: index[dimension]}) | ||
outer_index = Index({dk: index[dk] for dk in index if dk != dimension}) | ||
return outer_index, (inner_index, value) | ||
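A self-contained sketch of the same reshaping using plain dicts in place of the real Index/Dimension types (purely illustrative):

    # Stand-in for _nest_dim: pull one dimension out into an inner index.
    def nest_dim(item, dimension):
        index, value = item
        inner = {dimension: index[dimension]}
        outer = {k: v for k, v in index.items() if k != dimension}
        return outer, (inner, value)

    item = ({"time": 0, "variable": 2}, "dataset-placeholder")
    print(nest_dim(item, "time"))
    # ({'variable': 2}, ({'time': 0}, 'dataset-placeholder'))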
|
||
|
||
@dataclass | ||
class _NestDim(beam.PTransform): | ||
"""Prepare a collection for grouping by transforming an Index into a nested | ||
Tuple of Indexes. | ||
|
||
:param dimension: The dimension to nest | ||
""" | ||
|
||
dimension: Dimension | ||
|
||
def expand(self, pcoll): | ||
return pcoll | beam.Map(_nest_dim, dimension=self.dimension) | ||
|
||
|
||
@dataclass | ||
class DatasetToSchema(beam.PTransform): | ||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
return pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v))) | ||
|
||
|
||
@dataclass | ||
class DetermineSchema(beam.PTransform): | ||
"""Combine many Datasets into a single schema along multiple dimensions. | ||
|
@@ -280,47 +267,35 @@ class DetermineSchema(beam.PTransform): | |
|
||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
schemas = pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v))) | ||
cdims = self.combine_dims.copy() | ||
while len(cdims) > 0: | ||
last_dim = cdims.pop() | ||
if len(cdims) == 0: | ||
# at this point, we should have a 1D index as our key | ||
schemas = schemas | beam.CombineGlobally(CombineXarraySchemas(last_dim)) | ||
else: | ||
# Recursively combine schema definitions | ||
for i, dim in enumerate(self.combine_dims[::-1]): | ||
if i < len(self.combine_dims) - 1: | ||
schemas = ( | ||
schemas | ||
| f"Nest {last_dim.name}" >> _NestDim(last_dim) | ||
| f"Combine {last_dim.name}" | ||
>> beam.CombinePerKey(CombineXarraySchemas(last_dim)) | ||
| f"Nest {dim.name}" >> beam.Map(_nest_dim, dimension=dim) | ||
| f"Combine {dim.name}" >> beam.CombinePerKey(CombineXarraySchemas(dim)) | ||
) | ||
else: # at this point, we should have a 1D index as our key | ||
schemas = schemas | beam.CombineGlobally(CombineXarraySchemas(dim)) | ||
return schemas | ||
|
||
|
||
@dataclass | ||
class IndexItems(beam.PTransform): | ||
class IndexWithPosition(beam.DoFn): | ||
"""Augment dataset indexes with information about start and stop position.""" | ||
|
||
schema: beam.PCollection | ||
append_offset: int = 0 | ||
def __init__(self, append_offset=0): | ||
self.append_offset = append_offset | ||
|
||
@staticmethod | ||
def index_item(item: Indexed[T], schema: XarraySchema, append_offset: int) -> Indexed[T]: | ||
index, ds = item | ||
def process(self, element, schema): | ||
index, ds = element | ||
new_index = Index() | ||
for dimkey, dimval in index.items(): | ||
if dimkey.operation == CombineOp.CONCAT: | ||
item_len_dict = schema["chunks"][dimkey.name] | ||
item_lens = [item_len_dict[n] for n in range(len(item_len_dict))] | ||
dimval = augment_index_with_start_stop(dimval, item_lens, append_offset) | ||
dimval = augment_index_with_byte_range(dimval, item_lens, self.append_offset) | ||
new_index[dimkey] = dimval | ||
return new_index, ds | ||
|
||
def expand(self, pcoll: beam.PCollection): | ||
return pcoll | beam.Map( | ||
self.index_item, | ||
schema=beam.pvalue.AsSingleton(self.schema), | ||
append_offset=self.append_offset, | ||
) | ||
yield new_index, ds | ||
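For context, applying this DoFn looks roughly like the snippet below (the real wiring appears later in StoreToZarr; `datasets` and `schema` here are placeholder PCollections): the schema side input is materialized once and handed to process() as a keyword argument.

    indexed_datasets = datasets | beam.ParDo(
        IndexWithPosition(append_offset=0),
        # singleton side input, delivered to process(element, schema)
        schema=beam.pvalue.AsSingleton(schema),
    )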
|
||
|
||
@dataclass | ||
|
@@ -331,10 +306,6 @@ class PrepareZarrTarget(beam.PTransform): | |
Note that the dimension coordinates will be initialized with dummy values. | ||
|
||
:param target: Where to store the target Zarr dataset. | ||
:param target_chunks: Dictionary mapping dimension names to chunks sizes. | ||
If a dimension is a not named, the chunks will be inferred from the schema. | ||
If chunking is present in the schema for a given dimension, the length of | ||
the first fragment will be used. Otherwise, the dimension will not be chunked. | ||
:param attrs: Extra group-level attributes to inject into the dataset. | ||
:param encoding: Dictionary describing encoding for xarray.to_zarr() | ||
:param consolidated_metadata: Bool controlling if xarray.to_zarr() | ||
|
@@ -350,22 +321,23 @@ class PrepareZarrTarget(beam.PTransform): | |
""" | ||
|
||
target: str | FSSpecTarget | ||
target_chunks: Dict[str, int] = field(default_factory=dict) | ||
attrs: Dict[str, str] = field(default_factory=dict) | ||
consolidated_metadata: Optional[bool] = True | ||
encoding: Optional[dict] = field(default_factory=dict) | ||
append_dim: Optional[str] = None | ||
|
||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
def expand( | ||
self, pcoll: beam.PCollection, target_chunks: beam.pvalue.AsSingleton | ||
) -> beam.PCollection: | ||
if isinstance(self.target, str): | ||
target = FSSpecTarget.from_url(self.target) | ||
else: | ||
target = self.target | ||
store = target.get_mapper() | ||
initialized_target = pcoll | beam.Map( | ||
initialized_target = pcoll | "initialize zarr store" >> beam.Map( | ||
schema_to_zarr, | ||
target_store=store, | ||
target_chunks=self.target_chunks, | ||
target_chunks=target_chunks, | ||
attrs=self.attrs, | ||
encoding=self.encoding, | ||
consolidated_metadata=False, | ||
|
@@ -374,16 +346,6 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | |
return initialized_target | ||
|
||
|
||
@dataclass | ||
class StoreDatasetFragments(beam.PTransform): | ||
target_store: beam.PCollection # side input | ||
|
||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
return pcoll | beam.Map( | ||
store_dataset_fragment, target_store=beam.pvalue.AsSingleton(self.target_store) | ||
) | ||
|
||
|
||
@dataclass | ||
class ConsolidateMetadata(beam.PTransform): | ||
"""Calls Zarr Python consolidate_metadata on an existing Zarr store | ||
|
@@ -394,22 +356,24 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | |
|
||
|
||
@dataclass | ||
class Rechunk(beam.PTransform): | ||
target_chunks: Optional[Dict[str, int]] | ||
schema: beam.PCollection | ||
|
||
def expand(self, pcoll: beam.PCollection) -> beam.PCollection: | ||
new_fragments = ( | ||
class ChunkToTarget(beam.PTransform): | ||
def expand( | ||
self, | ||
pcoll: beam.PCollection, | ||
target_chunks: beam.pvalue.AsSingleton, | ||
schema: beam.pvalue.AsSingleton, | ||
) -> beam.PCollection: | ||
return ( | ||
pcoll | ||
| beam.FlatMap( | ||
| "key to chunks following schema" | ||
>> beam.FlatMap( | ||
split_fragment, | ||
target_chunks=self.target_chunks, | ||
schema=beam.pvalue.AsSingleton(self.schema), | ||
target_chunks=target_chunks, | ||
schema=schema, | ||
) | ||
| beam.GroupByKey() # this has major performance implication | ||
| beam.MapTuple(combine_fragments) | ||
| "group by write chunk key" >> beam.GroupByKey() # group by key ensures locality | ||
Review comment: Nice! this is very helpful developer commentary 😃 |
||
| "per chunk dataset merge" >> beam.MapTuple(combine_fragments) | ||
) | ||
return new_fragments | ||
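To make the locality remark concrete, here is the shape of the data I would expect at each stage, with placeholder keys and fragments (illustrative only, not real Index objects or datasets):

    after_split = [("chunk-0", "frag-a"), ("chunk-0", "frag-b"), ("chunk-1", "frag-c")]
    after_group_by_key = [("chunk-0", ["frag-a", "frag-b"]), ("chunk-1", ["frag-c"])]
    after_combine = [("chunk-0", "merged-chunk-0"), ("chunk-1", "merged-chunk-1")]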
|
||
|
||
class ConsolidateDimensionCoordinates(beam.PTransform): | ||
|
@@ -647,22 +611,35 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin): | |
based on the full dataset (e.g. divide along a certain dimension based on a desired chunk | ||
size in memory). For more advanced chunking strategies, check | ||
out https://github.com/jbusecke/dynamic_chunks | ||
:param dynamic_chunking_fn_kwargs: Optional keyword arguments for ``dynamic_chunking_fn``. | ||
:param attrs: Extra group-level attributes to inject into the dataset. | ||
:param encoding: Dictionary encoding for xarray.to_zarr(). | ||
:param append_dim: Optional name of the dimension to append to. | ||
|
||
Example of using a wrapper function to reduce the arity of a more complex dynamic_chunking_fn: | ||
|
||
Suppose there's a function `calculate_dynamic_chunks` that requires extra parameters: an | ||
`xarray.Dataset`, a `target_chunk_size` in bytes, and a `dim_name` along which to chunk. | ||
To fit the expected signature for `dynamic_chunking_fn`, we can define a wrapper function | ||
that presets `target_chunk_size` and `dim_name`: | ||
|
||
def calculate_dynamic_chunks(ds, target_chunk_size, dim_name) -> Dict[str, int]: | ||
... | ||
|
||
def dynamic_chunking_wrapper(ds: xarray.Dataset) -> Dict[str, int]: | ||
target_chunk_size = 1024 * 1024 * 10 | ||
dim_name = 'time' | ||
return calculate_dynamic_chunks(ds, target_chunk_size, dim_name) | ||
|
||
StoreToZarr(..., dynamic_chunking_fn=dynamic_chunking_wrapper, ...) | ||
""" | ||
|
||
# TODO: make it so we don't have to explicitly specify combine_dims | ||
# Could be inferred from the pattern instead | ||
Review comment on lines -656 to -657: I guess we're abandoning this idea? I have had my doubts that it's possible, since it would be a meta-operation over the pipeline. But I've been wrong about what's possible in Beam before.
Reply: The way this reads to me right now, we'd need to pass the same argument to a couple of different transforms (which isn't a bad pattern, afaict). To reuse such a value, we'd need to thread the input through as some kind of computed result, and I doubt that juice is worth the squeeze. I'm not opposed to keeping the TODO open, if that's desirable; I just couldn't quite imagine how this could be facilitated at the level of recipes. |
||
combine_dims: List[Dimension] | ||
store_name: str | ||
target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field( | ||
default_factory=RequiredAtRuntimeDefault | ||
) | ||
target_chunks: Dict[str, int] = field(default_factory=dict) | ||
dynamic_chunking_fn: Optional[Callable[[xr.Dataset], dict]] = None | ||
dynamic_chunking_fn_kwargs: Optional[dict] = field(default_factory=dict) | ||
Review comment on lines 642 to -665: Is there an advantage to lower arity here aside from an aesthetically tidier signature?
Reply: The tidy signature is much of the benefit I see, but it also hopefully passes responsibility downstream to pipeline writers/users in a way that can't hide any magic we might be applying. On this conception of things, the user has one thing to worry about: a function that produces the chunking they desire given a template dataset. |
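If a recipe still needs the extra parameters that dynamic_chunking_fn_kwargs used to carry, one option (a suggestion, not something this PR prescribes) is to pre-bind them so the callable keeps the one-argument signature. The helper and its size heuristic below are hypothetical:

    import functools
    import xarray as xr

    def chunk_by_time(ds: xr.Dataset, target_chunk_size_bytes: int) -> dict:
        # Hypothetical heuristic: aim each "time" chunk at roughly the target size.
        bytes_per_step = ds.nbytes / max(ds.sizes.get("time", 1), 1)
        return {"time": max(int(target_chunk_size_bytes / bytes_per_step), 1)}

    dynamic_chunking_fn = functools.partial(chunk_by_time, target_chunk_size_bytes=100_000_000)
    # StoreToZarr(..., dynamic_chunking_fn=dynamic_chunking_fn, ...)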
||
attrs: Dict[str, str] = field(default_factory=dict) | ||
encoding: Optional[dict] = field(default_factory=dict) | ||
append_dim: Optional[str] = None | ||
|
@@ -688,29 +665,56 @@ def expand( | |
self, | ||
datasets: beam.PCollection[Tuple[Index, xr.Dataset]], | ||
) -> beam.PCollection[zarr.storage.FSStore]: | ||
logger.info(f"Storing Zarr with {self.target_chunks =} to {self.get_full_target()}") | ||
|
||
pipeline = datasets.pipeline | ||
|
||
# build a global xarray schema (i.e. it ranges over all input datasets) | ||
schema = datasets | DetermineSchema(combine_dims=self.combine_dims) | ||
indexed_datasets = datasets | IndexItems(schema=schema, append_offset=self._append_offset) | ||
target_chunks = ( | ||
self.target_chunks | ||
|
||
# Index datasets according to their place within the global schema | ||
indexed_datasets = datasets | beam.ParDo( | ||
IndexWithPosition(append_offset=self._append_offset), | ||
schema=beam.pvalue.AsSingleton(schema), | ||
) | ||
|
||
# either use target chunks or else compute them with provided chunking function | ||
# Make a PColl for target chunks to match the output of mapping a dynamic chunking fn | ||
target_chunks_pcoll = ( | ||
pipeline | "Create target_chunks pcoll" >> beam.Create([self.target_chunks]) | ||
if not self.dynamic_chunking_fn | ||
else beam.pvalue.AsSingleton( | ||
else ( | ||
schema | ||
| beam.Map(schema_to_template_ds) | ||
| beam.Map(self.dynamic_chunking_fn, **self.dynamic_chunking_fn_kwargs) | ||
| "make template dataset" >> beam.Map(schema_to_template_ds) | ||
| "generate chunks dynamically" >> beam.Map(self.dynamic_chunking_fn) | ||
Review comment on lines +688 to +689: these labels really help readability! tysm! |
||
) | ||
) | ||
logger.info(f"Storing Zarr with {target_chunks =} to {self.get_full_target()}") | ||
rechunked_datasets = indexed_datasets | Rechunk(target_chunks=target_chunks, schema=schema) | ||
target_store = schema | PrepareZarrTarget( | ||
|
||
# split datasets according to their write-targets (chunks of bytes) | ||
# then combine datasets with shared write targets | ||
# Note that the pipe (|) operator in beam is just sugar for passing into expand | ||
rechunked_datasets = ChunkToTarget().expand( | ||
indexed_datasets, | ||
target_chunks=beam.pvalue.AsSingleton(target_chunks_pcoll), | ||
schema=beam.pvalue.AsSingleton(schema), | ||
) | ||
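To spell out the sugar remark above: the two applications below are roughly equivalent in Beam, and calling expand directly is what leaves room for the extra side-input arguments used here. A self-contained sketch with a throwaway transform (AddOne is a placeholder, not part of this PR):

    import apache_beam as beam

    class AddOne(beam.PTransform):
        def expand(self, pcoll):
            return pcoll | beam.Map(lambda x: x + 1)

    with beam.Pipeline() as p:
        nums = p | beam.Create([1, 2, 3])
        # pipe syntax: sugar that ends up invoking AddOne().expand(nums)
        piped = nums | "piped" >> AddOne()

        other = p | "more" >> beam.Create([10, 20])
        # direct call: builds the same steps, and permits extra arguments to expand
        direct = AddOne().expand(other)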
|
||
target_store = PrepareZarrTarget( | ||
target=self.get_full_target(), | ||
target_chunks=target_chunks, | ||
attrs=self.attrs, | ||
encoding=self.encoding, | ||
append_dim=self.append_dim, | ||
).expand(schema, beam.pvalue.AsSingleton(target_chunks_pcoll)) | ||
|
||
# Actually attempt to write datasets to their target bytes/files | ||
rechunking = rechunked_datasets | "write chunks" >> beam.Map( | ||
store_dataset_fragment, target_store=beam.pvalue.AsSingleton(target_store) | ||
) | ||
n_target_stores = rechunked_datasets | StoreDatasetFragments(target_store=target_store) | ||
|
||
# the last thing we need to do is extract the zarrstore target. To do this *after* | ||
# rechunking, we need to make the dependency on `rechunking` (and its side effects) explicit | ||
singleton_target_store = ( | ||
n_target_stores | ||
rechunking | ||
| beam.combiners.Sample.FixedSizeGlobally(1) | ||
| beam.FlatMap(lambda x: x) # https://stackoverflow.com/a/47146582 | ||
) | ||
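For what it's worth, the Sample-then-flatten idiom used here is a general Beam trick for collapsing a PCollection down to a single element while keeping an explicit dependency on the upstream step; a self-contained sketch (names are placeholders):

    import apache_beam as beam

    with beam.Pipeline() as p:
        writes_done = p | beam.Create(["store"])  # stand-in for the post-write PCollection
        singleton = (
            writes_done
            | beam.combiners.Sample.FixedSizeGlobally(1)  # -> one list of at most 1 element
            | beam.FlatMap(lambda x: x)                   # unwrap the list back into its element
        )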
|
Review comment: I believe this is indexing logical dimensions in the dataset, not bytes.
Reply: It looks to me as though this function uses the logical position of the provided index to calculate start/stop in terms of bytes, no? Like, the sum of item lengths calculated prior to it is determined based on logical position, but the resulting integer is supposed to be the byte-range start, is how I read it.