Organize, document, and simplify transforms.py #726

Open · wants to merge 2 commits into base: main
5 changes: 3 additions & 2 deletions pangeo_forge_recipes/patterns.py
@@ -63,12 +63,13 @@ class MergeDim(CombineDim):
operation: ClassVar[CombineOp] = CombineOp.MERGE


def augment_index_with_start_stop(
def augment_index_with_byte_range(
Member:

I believe this is indexing logical dimensions in the dataset, not bytes.

Contributor Author:

It looks to me as though this function uses the logical position of the provided index to calculate start/stop in terms of bytes, no? Like, the sum of item lengths calculated prior to it is determined by logical position, but the resulting integer is supposed to be the byte-range start; that's how I read it.
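For reference, a minimal sketch of the kind of computation under discussion, inferred from the signature and docstring above; the function body and the ``IndexedPosition`` constructor arguments here are assumptions, not a quotation of the implementation:

    def augment_index_with_byte_range(position, item_lens, append_offset=0):
        # "start" is the cumulative length of all items that precede this one
        # along the concat dimension, shifted by append_offset when appending.
        # Whether that offset is in bytes or in logical elements depends
        # entirely on the units of the item_lens supplied by the caller.
        start = sum(item_lens[: position.value]) + append_offset
        return IndexedPosition(start, dimsize=sum(item_lens))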

position: Position,
item_lens: List[int],
append_offset: int = 0,
) -> IndexedPosition:
"""Take an index _without_ start / stop and add them based on the lens defined in sequence_lens.
"""Take an index _without_ start / stop (byte range) and add them based on the lens defined in
Member:

Same comment as above re: bytes.

sequence_lens.

:param index: The ``DimIndex`` instance to augment.
:param item_lens: A list of integer lengths for all items in the sequence.
2 changes: 0 additions & 2 deletions pangeo_forge_recipes/rechunking.py
@@ -31,7 +31,6 @@ def split_fragment(
:param fragment: the indexed fragment.
:param target_chunks_and_dims: mapping from dimension name to a tuple of (chunksize, dimsize)
"""

logger.info(f"Splitting {fragment = }, with {target_chunks = } and {schema = }")

if target_chunks is None and schema is None:
@@ -41,7 +40,6 @@ def split_fragment(
target_chunks = determine_target_chunks(schema, target_chunks, include_all_dims=False)
else:
assert target_chunks is not None

index, ds = fragment

# target_chunks_and_dims contains both the chunk size and global dataset dimension size
192 changes: 98 additions & 94 deletions pangeo_forge_recipes/transforms.py
@@ -20,10 +20,10 @@
import zarr
from kerchunk.combine import MultiZarrToZarr

from .aggregation import XarraySchema, dataset_to_schema, schema_to_template_ds, schema_to_zarr
from .aggregation import dataset_to_schema, schema_to_template_ds, schema_to_zarr
from .combiners import CombineXarraySchemas, MinMaxCountCombineFn
from .openers import open_url, open_with_kerchunk, open_with_xarray
from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_start_stop
from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_byte_range
from .rechunking import combine_fragments, consolidate_dimension_coordinates, split_fragment
from .storage import CacheFSSpecTarget, FSSpecTarget
from .types import Indexed
@@ -241,33 +241,20 @@ def expand(self, pcoll):
)


@dataclass
class DatasetToSchema(beam.PTransform):
def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
return pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v)))


def _nest_dim(item: Indexed[T], dimension: Dimension) -> Indexed[Indexed[T]]:
"""Nest dimensions to support multiple combine dimensions"""
index, value = item
inner_index = Index({dimension: index[dimension]})
outer_index = Index({dk: index[dk] for dk in index if dk != dimension})
return outer_index, (inner_index, value)


@dataclass
class _NestDim(beam.PTransform):
"""Prepare a collection for grouping by transforming an Index into a nested
Tuple of Indexes.

:param dimension: The dimension to nest
"""

dimension: Dimension

def expand(self, pcoll):
return pcoll | beam.Map(_nest_dim, dimension=self.dimension)


@dataclass
class DatasetToSchema(beam.PTransform):
def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
return pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v)))


@dataclass
class DetermineSchema(beam.PTransform):
"""Combine many Datasets into a single schema along multiple dimensions.
@@ -280,47 +267,35 @@ class DetermineSchema(beam.PTransform):

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
schemas = pcoll | beam.MapTuple(lambda k, v: (k, dataset_to_schema(v)))
cdims = self.combine_dims.copy()
while len(cdims) > 0:
last_dim = cdims.pop()
if len(cdims) == 0:
# at this point, we should have a 1D index as our key
schemas = schemas | beam.CombineGlobally(CombineXarraySchemas(last_dim))
else:
# Recursively combine schema definitions
for i, dim in enumerate(self.combine_dims[::-1]):
if i < len(self.combine_dims) - 1:
schemas = (
schemas
| f"Nest {last_dim.name}" >> _NestDim(last_dim)
| f"Combine {last_dim.name}"
>> beam.CombinePerKey(CombineXarraySchemas(last_dim))
| f"Nest {dim.name}" >> beam.Map(_nest_dim, dimension=dim)
| f"Combine {dim.name}" >> beam.CombinePerKey(CombineXarraySchemas(dim))
)
else: # at this point, we should have a 1D index as our key
schemas = schemas | beam.CombineGlobally(CombineXarraySchemas(dim))
return schemas


@dataclass
class IndexItems(beam.PTransform):
class IndexWithPosition(beam.DoFn):
"""Augment dataset indexes with information about start and stop position."""

schema: beam.PCollection
append_offset: int = 0
def __init__(self, append_offset=0):
self.append_offset = append_offset

@staticmethod
def index_item(item: Indexed[T], schema: XarraySchema, append_offset: int) -> Indexed[T]:
index, ds = item
def process(self, element, schema):
index, ds = element
new_index = Index()
for dimkey, dimval in index.items():
if dimkey.operation == CombineOp.CONCAT:
item_len_dict = schema["chunks"][dimkey.name]
item_lens = [item_len_dict[n] for n in range(len(item_len_dict))]
dimval = augment_index_with_start_stop(dimval, item_lens, append_offset)
dimval = augment_index_with_byte_range(dimval, item_lens, self.append_offset)
new_index[dimkey] = dimval
return new_index, ds

def expand(self, pcoll: beam.PCollection):
return pcoll | beam.Map(
self.index_item,
schema=beam.pvalue.AsSingleton(self.schema),
append_offset=self.append_offset,
)
yield new_index, ds


@dataclass
@@ -331,10 +306,6 @@ class PrepareZarrTarget(beam.PTransform):
Note that the dimension coordinates will be initialized with dummy values.

:param target: Where to store the target Zarr dataset.
:param target_chunks: Dictionary mapping dimension names to chunks sizes.
If a dimension is a not named, the chunks will be inferred from the schema.
If chunking is present in the schema for a given dimension, the length of
the first fragment will be used. Otherwise, the dimension will not be chunked.
:param attrs: Extra group-level attributes to inject into the dataset.
:param encoding: Dictionary describing encoding for xarray.to_zarr()
:param consolidated_metadata: Bool controlling if xarray.to_zarr()
@@ -350,22 +321,23 @@
"""

target: str | FSSpecTarget
target_chunks: Dict[str, int] = field(default_factory=dict)
attrs: Dict[str, str] = field(default_factory=dict)
consolidated_metadata: Optional[bool] = True
encoding: Optional[dict] = field(default_factory=dict)
append_dim: Optional[str] = None

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
def expand(
self, pcoll: beam.PCollection, target_chunks: beam.pvalue.AsSingleton
) -> beam.PCollection:
if isinstance(self.target, str):
target = FSSpecTarget.from_url(self.target)
else:
target = self.target
store = target.get_mapper()
initialized_target = pcoll | beam.Map(
initialized_target = pcoll | "initialize zarr store" >> beam.Map(
schema_to_zarr,
target_store=store,
target_chunks=self.target_chunks,
target_chunks=target_chunks,
attrs=self.attrs,
encoding=self.encoding,
consolidated_metadata=False,
@@ -374,16 +346,6 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
return initialized_target


@dataclass
class StoreDatasetFragments(beam.PTransform):
target_store: beam.PCollection # side input

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
return pcoll | beam.Map(
store_dataset_fragment, target_store=beam.pvalue.AsSingleton(self.target_store)
)


@dataclass
class ConsolidateMetadata(beam.PTransform):
"""Calls Zarr Python consolidate_metadata on an existing Zarr store
@@ -394,22 +356,24 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:


@dataclass
class Rechunk(beam.PTransform):
target_chunks: Optional[Dict[str, int]]
schema: beam.PCollection

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
new_fragments = (
class ChunkToTarget(beam.PTransform):
def expand(
self,
pcoll: beam.PCollection,
target_chunks: beam.pvalue.AsSingleton,
schema: beam.pvalue.AsSingleton,
) -> beam.PCollection:
return (
pcoll
| beam.FlatMap(
| "key to chunks following schema"
>> beam.FlatMap(
split_fragment,
target_chunks=self.target_chunks,
schema=beam.pvalue.AsSingleton(self.schema),
target_chunks=target_chunks,
schema=schema,
)
| beam.GroupByKey() # this has major performance implication
| beam.MapTuple(combine_fragments)
| "group by write chunk key" >> beam.GroupByKey() # group by key ensures locality
Member:

Nice! This is very helpful developer commentary 😃

| "per chunk dataset merge" >> beam.MapTuple(combine_fragments)
)
return new_fragments


class ConsolidateDimensionCoordinates(beam.PTransform):
@@ -647,22 +611,35 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin):
based on the full dataset (e.g. divide along a certain dimension based on a desired chunk
size in memory). For more advanced chunking strategies, check
out https://github.com/jbusecke/dynamic_chunks
:param dynamic_chunking_fn_kwargs: Optional keyword arguments for ``dynamic_chunking_fn``.
:param attrs: Extra group-level attributes to inject into the dataset.
:param encoding: Dictionary encoding for xarray.to_zarr().
:param append_dim: Optional name of the dimension to append to.

Example of using a wrapper function to reduce the arity of a more complex dynamic_chunking_fn:

Suppose there's a function `calculate_dynamic_chunks` that requires extra parameters: an
`xarray.Dataset`, a `target_chunk_size` in bytes, and a `dim_name` along which to chunk.
To fit the expected signature for `dynamic_chunking_fn`, we can define a wrapper function
that presets `target_chunk_size` and `dim_name`:

def calculate_dynamic_chunks(ds, target_chunk_size, dim_name) -> Dict[str, int]:
...

def dynamic_chunking_wrapper(ds: xarray.Dataset) -> Dict[str, int]:
target_chunk_size = 1024 * 1024 * 10
dim_name = 'time'
return calculate_dynamic_chunks(ds, target_chunk_size, dim_name)

StoreToZarr(..., dynamic_chunking_fn=dynamic_chunking_wrapper, ...)
"""

# TODO: make it so we don't have to explicitly specify combine_dims
# Could be inferred from the pattern instead
Member (comment on lines -656 to -657):

I guess we're abandoning this idea? I have had my doubts that it's possible, since it would be a meta-operation over the pipeline. But I've been wrong about what's possible in Beam before.

Contributor Author:

The way this reads to me right now, we'd need to pass the same argument to a couple of different transforms (which isn't a bad pattern, afaict). To reuse such a value, we'd need to thread the input through as some kind of computed result, and I doubt that juice is worth the squeeze.

I'm not opposed to keeping the TODO open, if that's desirable. I just couldn't quite imagine how this could be facilitated at the level of recipes.

combine_dims: List[Dimension]
store_name: str
target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field(
default_factory=RequiredAtRuntimeDefault
)
target_chunks: Dict[str, int] = field(default_factory=dict)
dynamic_chunking_fn: Optional[Callable[[xr.Dataset], dict]] = None
dynamic_chunking_fn_kwargs: Optional[dict] = field(default_factory=dict)
Member (comment on lines 642 to -665):

Is there an advantage to lower arity here aside from an aesthetically tidier signature?

Contributor Author:

The tidy signature is much of the benefit I see, but it also, hopefully, passes responsibility downstream to pipeline writers/users in a way that can't hide any magic we might be applying. On this conception of things, the user has one thing to worry about: a function that produces the chunking they desire, given a template dataset.
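(Illustration only, not part of this diff: with the kwargs field removed, a pipeline author can still preset extra parameters, e.g. with ``functools.partial``, so that the callable passed to ``StoreToZarr`` takes just the template dataset. ``calculate_dynamic_chunks`` below is the hypothetical helper from the docstring example above.)

    import functools
    from typing import Dict

    def calculate_dynamic_chunks(ds, target_chunk_size, dim_name) -> Dict[str, int]:
        ...  # user-defined chunking logic

    # preset the extra arguments so only the template dataset remains
    dynamic_chunking_fn = functools.partial(
        calculate_dynamic_chunks, target_chunk_size=1024 * 1024 * 10, dim_name="time"
    )
    # StoreToZarr(..., dynamic_chunking_fn=dynamic_chunking_fn, ...)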

attrs: Dict[str, str] = field(default_factory=dict)
encoding: Optional[dict] = field(default_factory=dict)
append_dim: Optional[str] = None
@@ -688,29 +665,56 @@ def expand(
self,
datasets: beam.PCollection[Tuple[Index, xr.Dataset]],
) -> beam.PCollection[zarr.storage.FSStore]:
logger.info(f"Storing Zarr with {self.target_chunks =} to {self.get_full_target()}")

pipeline = datasets.pipeline

# build a global xarray schema (i.e. it ranges over all input datasets)
schema = datasets | DetermineSchema(combine_dims=self.combine_dims)
indexed_datasets = datasets | IndexItems(schema=schema, append_offset=self._append_offset)
target_chunks = (
self.target_chunks

# Index datasets according to their place within the global schema
indexed_datasets = datasets | beam.ParDo(
IndexWithPosition(append_offset=self._append_offset),
schema=beam.pvalue.AsSingleton(schema),
)

# either use target chunks or else compute them with provided chunking function
# Make a PColl for target chunks to match the output of mapping a dynamic chunking fn
target_chunks_pcoll = (
pipeline | "Create target_chunks pcoll" >> beam.Create([self.target_chunks])
if not self.dynamic_chunking_fn
else beam.pvalue.AsSingleton(
else (
schema
| beam.Map(schema_to_template_ds)
| beam.Map(self.dynamic_chunking_fn, **self.dynamic_chunking_fn_kwargs)
| "make template dataset" >> beam.Map(schema_to_template_ds)
| "generate chunks dynamically" >> beam.Map(self.dynamic_chunking_fn)
Member (comment on lines +688 to +689):

These labels really help readability! tysm!

)
)
logger.info(f"Storing Zarr with {target_chunks =} to {self.get_full_target()}")
rechunked_datasets = indexed_datasets | Rechunk(target_chunks=target_chunks, schema=schema)
target_store = schema | PrepareZarrTarget(

# split datasets according to their write-targets (chunks of bytes)
# then combine datasets with shared write targets
# Note that the pipe (|) operator in beam is just sugar for passing into expand
rechunked_datasets = ChunkToTarget().expand(
indexed_datasets,
target_chunks=beam.pvalue.AsSingleton(target_chunks_pcoll),
schema=beam.pvalue.AsSingleton(schema),
)

target_store = PrepareZarrTarget(
target=self.get_full_target(),
target_chunks=target_chunks,
attrs=self.attrs,
encoding=self.encoding,
append_dim=self.append_dim,
).expand(schema, beam.pvalue.AsSingleton(target_chunks_pcoll))

# Actually attempt to write datasets to their target bytes/files
rechunking = rechunked_datasets | "write chunks" >> beam.Map(
store_dataset_fragment, target_store=beam.pvalue.AsSingleton(target_store)
)
n_target_stores = rechunked_datasets | StoreDatasetFragments(target_store=target_store)

# the last thing we need to do is extract the zarrstore target. To do this *after*
# rechunking, we need to make the dependency on `rechunking` (and its side effects) explicit
singleton_target_store = (
n_target_stores
rechunking
| beam.combiners.Sample.FixedSizeGlobally(1)
| beam.FlatMap(lambda x: x) # https://stackoverflow.com/a/47146582
)
6 changes: 3 additions & 3 deletions tests/test_combiners.py
@@ -17,7 +17,7 @@
CombineReferences,
DatasetToSchema,
DetermineSchema,
_NestDim,
_nest_dim,
)
from pangeo_forge_recipes.types import CombineOp, Dimension, Index, Position

@@ -264,12 +264,12 @@ def _check(actual):
input = p | pcoll
group1 = (
input
| "Nest CONCAT" >> _NestDim(Dimension("time", CombineOp.CONCAT))
| "Nest CONCAT" >> beam.Map(_nest_dim, dimension=Dimension("time", CombineOp.CONCAT))
| "Groupby CONCAT" >> beam.GroupByKey()
)
group2 = (
input
| "Nest MERGE" >> _NestDim(Dimension("variable", CombineOp.MERGE))
| "Nest MERGE" >> beam.Map(_nest_dim, dimension=Dimension("variable", CombineOp.MERGE))
| "Groupy MERGE" >> beam.GroupByKey()
)
assert_that(group1, check_key(merge_only_indexes, concat_only_indexes), label="merge")
1 change: 1 addition & 0 deletions tests/test_end_to_end.py
@@ -274,6 +274,7 @@ def test_xarray_zarr_consolidate_dimension_coordinates(
target_root=tmp_target,
store_name="subpath",
combine_dims=pattern.combine_dim_keys,
target_chunks={"time": 10, "lat": 18, "lon": 36},
)
| ConsolidateDimensionCoordinates()
| ConsolidateMetadata()