
Support saving as netcdf InferenceData that has MultiIndex coordinates #2165

Open
lucianopaz opened this issue Nov 16, 2022 · 8 comments
@lucianopaz
Contributor

Tell us about it

There are many situations in which it is very convenient to use pandas.MultiIndex as coordinates of an xarray.DataArray. The problem is that, at the moment, xarray doesn't provide a built-in way to save these indexes in netCDF format. Take for example:

from arviz.tests.helpers import create_model

idata = create_model()
idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
idata.to_netcdf("test.nc")

This raises a NotImplementedError with the following traceback:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-15-43e455b97609> in <module>
      3 idata = create_model()
      4 idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
----> 5 idata.to_netcdf("test.nc")

~/anaconda3/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    442                         if _compressible_dtype(values.dtype)
    443                     }
--> 444                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    445                 data.close()
    446                 mode = "a"

~/anaconda3/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1898         from ..backends.api import to_netcdf
   1899 
-> 1900         return to_netcdf(
   1901             self,
   1902             path,

~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1070         # TODO: allow this work (setting up the file for writing array data)
   1071         # to be parallelized with dask
-> 1072         dump_to_store(
   1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )

~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118 
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120 
   1121 

~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    259             writer = ArrayWriter()
    260 
--> 261         variables, attributes = self.encode(variables, attributes)
    262 
    263         self.set_attributes(attributes)

~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    348         # All NetCDF files get CF encoded by default, without this attempting
    349         # to write times, for example, would fail.
--> 350         variables, attributes = cf_encoder(variables, attributes)
    351         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    352         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    857     _update_bounds_encoding(variables)
    858 
--> 859     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    860 
    861     # Remove attrs from bounds variables (issue #2921)

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
    857     _update_bounds_encoding(variables)
    858 
--> 859     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    860 
    861     # Remove attrs from bounds variables (issue #2921)

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    262         A variable which has been encoded as described above.
    263     """
--> 264     ensure_not_multiindex(var, name=name)
    265 
    266     for coder in [

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in ensure_not_multiindex(var, name)
    177 def ensure_not_multiindex(var, name=None):
    178     if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
--> 179         raise NotImplementedError(
    180             "variable {!r} is a MultiIndex, which cannot yet be "
    181             "serialized to netCDF files "

NotImplementedError: variable 'sample' is a MultiIndex, which cannot yet be serialized to netCDF files (https://github.com/pydata/xarray/issues/1077). Use reset_index() to convert MultiIndex levels into coordinate variables instead.
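For reference, the workaround that the error message suggests can be sketched as follows (a minimal illustration, not arviz code; the dataset and variable names are made up, and only numpy and xarray are assumed):

```python
# Minimal sketch of the reset_index() workaround suggested by the error message.
# The dataset below is illustrative; only numpy and xarray are assumed.
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"x": (("chain", "draw"), np.arange(6.0).reshape(2, 3))},
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
).stack(sample=["chain", "draw"])

# reset_index turns the MultiIndex levels into plain coordinate variables,
# which netCDF can serialize; flat.to_netcdf("test.nc") would now succeed.
flat = ds.reset_index("sample")

# After reading the file back, the MultiIndex has to be rebuilt by hand:
restored = flat.set_index(sample=["chain", "draw"])
```

The drawback, of course, is that every consumer of the file has to know to call set_index after loading, which is exactly the bookkeeping this issue proposes to automate.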

Thoughts on implementation

I had a look at the xarray issue mentioned in the error, and the approach suggested by @dcherian there works (at least in the scenario I had to deal with a month ago). I think it would be good to incorporate something like it into arviz.from_netcdf and InferenceData.to_netcdf. The basic idea is to convert the MultiIndex into a plain array of integers (the codes of the MultiIndex), and to add an attribute stating that the dimension/coordinates were originally a MultiIndex. This attribute is also used to keep track of the level values and names of the original MultiIndex. The modified data structure can be serialized to netCDF without any problems. The only thing to be aware of is that when the netCDF file is loaded, some work has to happen to rebuild the MultiIndex from the stored coordinates. I think that this small overhead is worth the benefit of bringing MultiIndex support to arviz.
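As a rough illustration of the idea (the function names and the attrs layout, "multiindex_names" / "multiindex_level_*", are hypothetical, not an existing arviz or xarray API):

```python
# Rough sketch of the encode/decode roundtrip described above. The function
# names and the attrs keys are hypothetical, not an existing arviz/xarray API.
import numpy as np
import pandas as pd
import xarray as xr

def encode_multiindex(ds: xr.Dataset, dim: str) -> xr.Dataset:
    """Replace the MultiIndex along `dim` with plain integer codes and record
    the original level names and values in the coordinate's attrs."""
    idx = ds.indexes[dim]
    assert isinstance(idx, pd.MultiIndex)
    encoded = ds.reset_index(dim, drop=True)
    encoded[dim] = np.arange(idx.size)  # plain integer coordinate
    encoded[dim].attrs["multiindex_names"] = list(idx.names)
    for name in idx.names:
        encoded[dim].attrs[f"multiindex_level_{name}"] = np.asarray(
            idx.get_level_values(name)
        )
    return encoded

def decode_multiindex(ds: xr.Dataset, dim: str) -> xr.Dataset:
    """Rebuild the MultiIndex along `dim` from the attrs written above."""
    names = list(ds[dim].attrs["multiindex_names"])
    decoded = ds
    for name in names:
        decoded = decoded.assign_coords(
            {name: (dim, ds[dim].attrs[f"multiindex_level_{name}"])}
        )
    return decoded.set_index({dim: names})
```

A real implementation would hook something like this into InferenceData.to_netcdf and arviz.from_netcdf; whether the level values live in attrs or in separate variables is a design detail.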

If you all agree that this would be valuable, I can write a PR.

@ahartikainen
Contributor

I'm not sure we want to deviate from the netCDF4 spec.

We could add functionality to transform to and from a MultiIndex using suitable info in attrs, but it wouldn't then be part of the official spec.

@lucianopaz
Contributor Author

I understand @ahartikainen, but I think there are scenarios where it is very helpful to support MultiIndex coordinates, especially when you can't simply unstack the dimension because some combinations of level values have no data in them.

@lucianopaz
Contributor Author

It could also be useful in situations where posterior samples are streamed into the Dataset in no particular order, i.e. the points for each (chain, draw) pair could arrive in an unpredictable order. For example, you could get

  1. (chain=0, draw=0)
  2. (chain=1, draw=0)
  3. (chain=0, draw=1)
  4. (chain=0, draw=2)
  5. (chain=1, draw=1)
    ...

and you wouldn't have to wait for all chains to finish to collect the posterior samples into a uniform (chain, draw) grid. You would be able to concatenate based on a MultiIndex defined in terms of (chain, draw) tuples (i.e. sample=("chain", "draw")).
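The streaming scenario above can be sketched like this (illustrative code, not arviz API; the `sample_point` helper and the toy stream are made up):

```python
# Illustrative sketch: out-of-order (chain, draw) samples are concatenated
# along a stacked "sample" dimension without waiting for a full rectangular
# (chain, draw) grid. The sample_point helper and the stream are made up.
import xarray as xr

def sample_point(chain, draw, value):
    """A single posterior draw, stacked so that sample=(chain, draw)."""
    return xr.Dataset(
        {"x": (("chain", "draw"), [[value]])},
        coords={"chain": [chain], "draw": [draw]},
    ).stack(sample=["chain", "draw"])

# Draws arrive in the unpredictable order described above.
stream = [(0, 0, 1.0), (1, 0, 2.0), (0, 1, 3.0), (0, 2, 4.0), (1, 1, 5.0)]
posterior = xr.concat([sample_point(c, d, v) for c, d, v in stream], dim="sample")

# (chain=1, draw=2) never arrived, and no NaN padding is needed;
# individual points are addressable by their (chain, draw) tuple:
point = posterior.sel(sample=(0, 2))
```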

@AlexAndorra
Contributor

I just ran into the same issue, and I agree it would be great to add this to ArviZ! Working with a MultiIndex is still quite hard in the xarray world, so anything that makes it more accessible seems worth it to me.
Would we output a UserWarning in that case, @ahartikainen, to make sure people know we're departing from the official spec?

@dcherian

I packaged that code in cf-xarray if you want to use that. That version fixes a couple of bugs.

@ahartikainen
Contributor

The problem is how these things are supposed to work in other languages.

In my opinion, MultiIndex items can be saved as coords against a stacking dimension.

@OriolAbril
Member

I think it would be best to use the implementation in cf-xarray. Is it critical for this to happen automatically when calling to_netcdf?

Some options that come to mind are:

  1. Add an alias to the cf-xarray encode/decode functions in the arviz namespace, but still require calling it explicitly before saving. cf-xarray only depends on xarray, so adding it as a dependency doesn't significantly grow our dependency graph.
  2. Document the use of cf-xarray; maybe also catch that error and modify it to point to cf-xarray. Is that something that could happen directly in xarray instead of pointing to the issue, @dcherian?
  3. Add cf-xarray as an optional dependency and, if it is installed and MultiIndexes are present, use it within to/from_netcdf. Maybe this should only be done if an rcParam is set to true?

@dcherian

modify the error to point to cf-xarray.

I think this would be a great PR to xarray!
