Merging coordinates computes array values #9481

Open
shoyer opened this issue Sep 11, 2024 · 2 comments

@shoyer
Member

shoyer commented Sep 11, 2024

What is your issue?

Xarray's default handling of coordinate merging (e.g., as used in arithmetic) computes the values of lazy coordinate arrays, which is not ideal.

(There is probably an older issue to discuss this, but I couldn't find it with a quick search)

This is easiest to see using Dask:

import xarray
import numpy as np
import dask.array

def r(*args):
    raise RuntimeError('data accessed')

x1 = dask.array.from_delayed(dask.delayed(r)(1), shape=(), dtype=np.float64)  # lazy scalar; computing it calls r()
x2 = dask.array.from_delayed(dask.delayed(r)(2), shape=(), dtype=np.float64)  # a second, distinct lazy scalar
ds1 = xarray.Dataset(coords={'x': x1})
ds2 = xarray.Dataset(coords={'x': x2})
ds1 + ds2  # RuntimeError: data accessed

Traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 12
     10 ds1 = xarray.Dataset(coords={'x': x1})
     11 ds2 = xarray.Dataset(coords={'x': x2})
---> 12 ds1 + ds2

File ~/dev/xarray/xarray/core/_typed_ops.py:35, in DatasetOpsMixin.__add__(self, other)
     34 def __add__(self, other: DsCompatible) -> Self:
---> 35     return self._binary_op(other, operator.add)

File ~/dev/xarray/xarray/core/dataset.py:7783, in Dataset._binary_op(self, other, f, reflexive, join)
   7781     self, other = align(self, other, join=align_type, copy=False)
   7782 g = f if not reflexive else lambda x, y: f(y, x)
-> 7783 ds = self._calculate_binary_op(g, other, join=align_type)
   7784 keep_attrs = _get_keep_attrs(default=False)
   7785 if keep_attrs:

File ~/dev/xarray/xarray/core/dataset.py:7844, in Dataset._calculate_binary_op(self, f, other, join, inplace)
   7841     return type(self)(new_data_vars)
   7843 other_coords: Coordinates | None = getattr(other, "coords", None)
-> 7844 ds = self.coords.merge(other_coords)
   7846 if isinstance(other, Dataset):
   7847     new_vars = apply_over_both(
   7848         self.data_vars, other.data_vars, self.variables, other.variables
   7849     )

File ~/dev/xarray/xarray/core/coordinates.py:522, in Coordinates.merge(self, other)
    519 if not isinstance(other, Coordinates):
    520     other = Dataset(coords=other).coords
--> 522 coords, indexes = merge_coordinates_without_align([self, other])
    523 coord_names = set(coords)
    524 return Dataset._construct_direct(
    525     variables=coords, coord_names=coord_names, indexes=indexes
    526 )

File ~/dev/xarray/xarray/core/merge.py:413, in merge_coordinates_without_align(objects, prioritized, exclude_dims, combine_attrs)
    409     filtered = collected
    411 # TODO: indexes should probably be filtered in collected elements
    412 # before merging them
--> 413 merged_coords, merged_indexes = merge_collected(
    414     filtered, prioritized, combine_attrs=combine_attrs
    415 )
    416 merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords))
    418 return merged_coords, merged_indexes

File ~/dev/xarray/xarray/core/merge.py:290, in merge_collected(grouped, prioritized, compat, combine_attrs, equals)
    288 variables = [variable for variable, _ in elements_list]
    289 try:
--> 290     merged_vars[name] = unique_variable(
    291         name, variables, compat, equals.get(name, None)
    292     )
    293 except MergeError:
    294     if compat != "minimal":
    295         # we need more than "minimal" compatibility (for which
    296         # we drop conflicting coordinates)

File ~/dev/xarray/xarray/core/merge.py:137, in unique_variable(name, variables, compat, equals)
    133         break
    135 if equals is None:
    136     # now compare values with minimum number of computes
--> 137     out = out.compute()
    138     for var in variables[1:]:
    139         equals = getattr(out, compat)(var)

File ~/dev/xarray/xarray/core/variable.py:1003, in Variable.compute(self, **kwargs)
    985 """Manually trigger loading of this variable's data from disk or a
    986 remote source into memory and return a new variable. The original is
    987 left unaltered.
   (...)
   1000 dask.array.compute
   1001 """
   1002 new = self.copy(deep=False)
-> 1003 return new.load(**kwargs)

File ~/dev/xarray/xarray/core/variable.py:981, in Variable.load(self, **kwargs)
    964 def load(self, **kwargs):
    965     """Manually trigger loading of this variable's data from disk or a
    966     remote source into memory and return this variable.
    967
   (...)
    979     dask.array.compute
    980     """
--> 981     self._data = to_duck_array(self._data, **kwargs)
    982     return self

File ~/dev/xarray/xarray/namedarray/pycompat.py:130, in to_duck_array(data, **kwargs)
    128 if is_chunked_array(data):
    129     chunkmanager = get_chunked_array_type(data)
--> 130     loaded_data, *_ = chunkmanager.compute(data, **kwargs)  # type: ignore[var-annotated]
    131     return loaded_data
    133 if isinstance(data, ExplicitlyIndexed):

File ~/dev/xarray/xarray/namedarray/daskmanager.py:86, in DaskManager.compute(self, *data, **kwargs)
     81 def compute(
     82     self, *data: Any, **kwargs: Any
     83 ) -> tuple[np.ndarray[Any, _DType_co], ...]:
     84     from dask.array import compute
---> 86     return compute(*data, **kwargs)

File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/dask/base.py:664, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    661     postcomputes.append(x.__dask_postcompute__())
    663 with shorten_traceback():
--> 664     results = schedule(dsk, keys, **kwargs)
    666 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

Cell In[2], line 6, in r(*args)
      5 def r(*args):
----> 6     raise RuntimeError('data accessed')

RuntimeError: data accessed

We use this check to decide whether or not to preserve coordinates on result objects: if a coordinate has the same values in all arguments, it is kept; otherwise it is dropped.
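A toy model of this keep-or-drop rule shows why it forces computation: deciding whether values agree means loading every candidate. This is a minimal sketch, not xarray's implementation (the real logic lives in `merge_collected` and `unique_variable`, as the traceback shows); `LazyCoord` and `merge_keep_if_equal` are hypothetical names.

```python
from __future__ import annotations

from typing import Any, Callable


class LazyCoord:
    """Hypothetical stand-in for a lazily loaded (e.g. dask-backed) coordinate."""

    def __init__(self, loader: Callable[[], Any]) -> None:
        self.loader = loader
        self.computes = 0  # how many times the data has been loaded

    def compute(self) -> Any:
        self.computes += 1
        return self.loader()


def merge_keep_if_equal(coord_sets: list[dict[str, LazyCoord]]) -> dict[str, Any]:
    """Keep a coordinate only if every object that has it agrees on its value.

    Conflicting coordinates are silently dropped instead of raising.
    """
    merged: dict[str, Any] = {}
    seen: set[str] = set()
    for coords in coord_sets:
        for name in coords:
            if name in seen:
                continue
            seen.add(name)
            candidates = [c[name] for c in coord_sets if name in c]
            # The comparison needs concrete values, so every candidate is loaded.
            values = [cand.compute() for cand in candidates]
            if all(v == values[0] for v in values[1:]):
                merged[name] = values[0]
    return merged


a = {"x": LazyCoord(lambda: 1.0), "label": LazyCoord(lambda: "a")}
b = {"x": LazyCoord(lambda: 1.0), "label": LazyCoord(lambda: "b")}
result = merge_keep_if_equal([a, b])
print(result)  # {'x': 1.0}: 'label' conflicts, so it is dropped
```

Note that even the coordinate that survives (`x`) had to be loaded from both objects to reach that decision.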

There are checks for matching array identity inside Variable.equals, so in practice the computation is often skipped, but this isn't ideal. It's basically the only case where Xarray operations on Xarray objects require computing lazy array values.
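The identity shortcut can be sketched the same way: a cheap check that never touches data runs first, and values are loaded only when it is inconclusive. This is a hypothetical model; the `name` attribute stands in for something like a dask array's name, under the assumption that equal names imply equal data.

```python
from __future__ import annotations

from typing import Any, Callable


class LazyCoord:
    """Hypothetical lazy coordinate with a name that identifies its data."""

    def __init__(self, name: str, loader: Callable[[], Any]) -> None:
        self.name = name  # assumption: equal names imply equal data
        self.loader = loader
        self.computes = 0

    def compute(self) -> Any:
        self.computes += 1
        return self.loader()


def known_equal(a: LazyCoord, b: LazyCoord) -> bool:
    """Cheap check that never loads data: same object or same name."""
    return a is b or a.name == b.name


def coords_equal(a: LazyCoord, b: LazyCoord) -> bool:
    if known_equal(a, b):
        return True  # short-circuit: no computation needed
    # Inconclusive: fall back to loading and comparing the values.
    return a.compute() == b.compute()


shared = LazyCoord("x-1234", lambda: 1.0)
alias = LazyCoord("x-1234", lambda: 1.0)
other = LazyCoord("x-5678", lambda: 1.0)

print(coords_equal(shared, alias), shared.computes)  # True 0
print(coords_equal(shared, other), other.computes)   # True 1
```

In the reproduction above, `x1` and `x2` are built from different delayed calls, so a check like this would be inconclusive and the fallback would compute them.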

The simplest fix would be to switch the default compat option used for merging inside arithmetic (and other xarray internal operations) to "override", so coordinates are simply copied from the first object on which they appear. Would this make sense?
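A minimal sketch of the proposed "override" behavior, assuming it simply means "first object wins": values are never compared, so a coordinate that fails when read (like `r()` in the reproduction) is never touched. `ExplodingCoord` and `merge_override` are hypothetical names; `compat="override"` itself is the option named above.

```python
from __future__ import annotations

from typing import Any


class ExplodingCoord:
    """Hypothetical lazy coordinate that, like r() above, fails if its data is read."""

    def __init__(self, label: Any) -> None:
        self.label = label

    def compute(self) -> Any:
        raise RuntimeError("data accessed")


def merge_override(coord_sets: list[dict[str, Any]]) -> dict[str, Any]:
    """Take each coordinate from the first object on which it appears.

    Values are never compared, so nothing is loaded.
    """
    merged: dict[str, Any] = {}
    for coords in coord_sets:
        for name, coord in coords.items():
            merged.setdefault(name, coord)
    return merged


coords1 = {"x": ExplodingCoord(1)}
coords2 = {"x": ExplodingCoord(2), "y": ExplodingCoord(3)}
merged = merge_override([coords1, coords2])
print({name: c.label for name, c in merged.items()})  # {'x': 1, 'y': 3}
```

The trade-off is that conflicting coordinates are no longer detected or dropped; the second object's `x` is silently ignored.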

@shoyer shoyer added the needs triage Issue that has not been reviewed by xarray team member label Sep 11, 2024
@dcherian
Contributor

the simplest fix would be to switch the default compat option used for merging inside arithmetic (and other xarray internal operations) to "override"

Yes, I believe I added some special casing for dask deep in xarray (checking dask_array.name for equality) to make this cheap a long time ago, but "override" does seem better. Should we also check shape/dtype equality though for a little more correctness?
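A shape/dtype check could be layered on a first-wins merge without computing anything, since that metadata is available on lazy arrays. This is a hypothetical sketch; `CoordMeta`, `CoordinateMismatch`, and `merge_override_checked` are illustrative names, and dtypes are modeled as plain strings.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class CoordMeta:
    """Hypothetical metadata known without loading data (as for a dask array)."""

    shape: tuple[int, ...]
    dtype: str


class CoordinateMismatch(ValueError):
    """Hypothetical error type for this sketch."""


def merge_override_checked(coord_sets: list[dict[str, CoordMeta]]) -> dict[str, CoordMeta]:
    """First-wins merge that also requires matching shape and dtype.

    Only metadata is compared, so no values are computed.
    """
    merged: dict[str, CoordMeta] = {}
    for coords in coord_sets:
        for name, meta in coords.items():
            first = merged.setdefault(name, meta)
            if first != meta:  # dataclass equality compares shape and dtype
                raise CoordinateMismatch(f"{name!r}: {first} vs {meta}")
    return merged


ok = merge_override_checked([{"x": CoordMeta((), "float64")}, {"x": CoordMeta((), "float64")}])
try:
    merge_override_checked([{"x": CoordMeta((), "float64")}, {"x": CoordMeta((), "int64")}])
except CoordinateMismatch as err:
    print("rejected:", err)
```

Under this check, pairing a scalar `float64` coordinate with a scalar `int64` one raises, whereas the current default accepts it.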

Related: #8778

@TomNicholas TomNicholas added topic-lazy array topic-combine combine/concat/merge and removed needs triage Issue that has not been reviewed by xarray team member labels Sep 11, 2024
@shoyer
Member Author

shoyer commented Sep 11, 2024

Should we also check shape/dtype equality though for a little more correctness?

We currently don't enforce these (I guess we default to compat='broadcast_equals'?), so it would be a breaking change to add this check. This suggests sticking with compat='override'.

In fact, in the case of mismatched dtypes, it looks like the result dtype (but not the result shape) already depends on the order of arithmetic arguments :/

ds1 = xarray.Dataset(coords={'x': 1.0})
ds2 = xarray.Dataset(coords={'x': 1})
ds3 = xarray.Dataset(coords={'x': ('y', [1])})
In [13]: ds1 + ds2
Out[13]:
<xarray.Dataset> Size: 8B
Dimensions:  ()
Coordinates:
    x        float64 8B 1.0
Data variables:
    *empty*

In [14]: ds2 + ds1
Out[14]:
<xarray.Dataset> Size: 8B
Dimensions:  ()
Coordinates:
    x        int64 8B 1
Data variables:
    *empty*

In [15]: ds1 + ds3
Out[15]:
<xarray.Dataset> Size: 8B
Dimensions:  (y: 1)
Coordinates:
    x        (y) float64 8B 1.0
Dimensions without coordinates: y
Data variables:
    *empty*

In [16]: ds3 + ds1
Out[16]:
<xarray.Dataset> Size: 8B
Dimensions:  (y: 1)
Coordinates:
    x        (y) int64 8B 1
Dimensions without coordinates: y
Data variables:
    *empty*
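A toy model consistent with these outputs: broadcast the dimensions across all candidates, then keep the first candidate's value if every value compares equal. Because `1.0 == 1` is `True`, both orders pass the check, but the surviving type comes from whichever argument is first. `Coord` and `merge_broadcast_keep_first` are hypothetical, and numpy dtypes are reduced to Python `int`/`float`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class Coord:
    """Toy coordinate: dimension names plus one scalar repeated along them."""

    dims: tuple[str, ...]
    value: object


def merge_broadcast_keep_first(coords: list[Coord]) -> Optional[Coord]:
    """Broadcast dims across candidates; keep the FIRST value if all compare equal.

    Returns None when values conflict (the coordinate would be dropped).
    """
    dims: list[str] = []
    for c in coords:
        for d in c.dims:
            if d not in dims:
                dims.append(d)
    first = coords[0]
    if any(c.value != first.value for c in coords[1:]):
        return None
    # Shape comes from the union of dims; the value (and its type) from the first argument.
    return Coord(tuple(dims), first.value)


ds1_x = Coord((), 1.0)    # like ds1: scalar float
ds3_x = Coord(("y",), 1)  # like ds3: int along y
for pair in ([ds1_x, ds3_x], [ds3_x, ds1_x]):
    out = merge_broadcast_keep_first(pair)
    print(out.dims, type(out.value).__name__)
# ('y',) float
# ('y',) int
```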
