
Use pytorch as backend for xarrays #3232

Open
fjanoos opened this issue Aug 19, 2019 · 49 comments · Fixed by #6804
Labels: topic-arrays (related to flexible array support), upstream issue

Comments

@fjanoos

fjanoos commented Aug 19, 2019

I would be interested in using pytorch as a backend for xarray, because:
a) pytorch is very similar to numpy, so the conceptual overhead is small;
b) [most helpful] it would enable using a GPU as the underlying compute hardware, which would provide a non-trivial speed-up;
c) it would allow seamless integration with deep-learning algorithms and techniques.

Any thoughts on what the interest for such a feature might be? I would be open to implementing parts of it, so any suggestions on where I could start?

Thanks

@shoyer
Member

shoyer commented Aug 20, 2019

If pytorch implements overrides of NumPy's API via the __array_function__ protocol, then this could work with minimal effort. We are already using this to support sparse arrays (it isn't in an official release yet, but the functionality is working in the development version).
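For reference, a minimal sketch of how the NEP 18 __array_function__ protocol works (the class and handler names are purely illustrative, not taken from any library):

import numpy as np

HANDLED_FUNCTIONS = {}

class DuckArray:
    """Toy duck array that overrides NumPy functions via NEP 18."""

    def __init__(self, data):
        self._data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented  # NumPy will then raise a TypeError
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

def implements(numpy_function):
    """Register a DuckArray implementation of a NumPy function."""
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator

@implements(np.mean)
def mean(arr, *args, **kwargs):
    return DuckArray(np.mean(arr._data, *args, **kwargs))

print(type(np.mean(DuckArray([1.0, 2.0, 3.0]))))  # DuckArray, not np.ndarray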

I think there has been some discussion about this, but I don't know the current status (CC @rgommers). The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API.

Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with JAX, which already implements NumPy's API almost exactly. I have an experimental pull request adding __array_function__ to JAX, but it still needs a bit of work to finish it up, e.g., we probably want to hide this behind a flag at first.

@rgommers

I think there has been some discussion about this, but I don't know the current status (CC @rgommers).

The PyTorch team is definitely receptive to the idea of adding __array_function__ and __array_ufunc__, as well as expanding the API for better NumPy compatibility.

Also, they want a Tensor.__torch_function__ styled after __array_function__ so they can make their own API overridable.

The tracking issue for all of this is pytorch/pytorch#22402

The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API.

Agreed. No one is working on __array_function__ at the moment. Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) being coerced to ndarray. It's not a small project, but implementing a prototype with a few functions in the torch namespace that do not exactly match the NumPy API would be a useful way to start pushing this forward.

@rgommers

Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with JAX, which already implements NumPy's API almost exactly.

Less familiar with that, but pytorch does have experimental XLA support, so that's a start.

@shoyer
Member

shoyer commented Aug 20, 2019

Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) to be coerced to ndarray.

Yes, this is a concern for JAX as well. This is a definite downside of reusing NumPy's existing namespace.

It turns out even xarray was relying on this behavior with dask in at least one edge case: #3215

@rgommers

This is a definite downside of reusing NumPy's existing namespace.

We didn't discuss an alternative very explicitly, I think, but at least we'll have wide adoption fast. Hopefully the pain is limited.

@fjanoos
Author

fjanoos commented Aug 23, 2019

I haven't used JAX, but I was just browsing through its documentation and it looks super cool. Any ideas on how it compares with pytorch in terms of:

a) execution speed, especially on GPU;
b) memory management on GPUs (pytorch has the DataLoader/Dataset paradigm, which uses background workers to shuttle batches of data back and forth, along with a lot of tips and tricks on efficient memory usage);
c) support for deep-learning optimization algorithms?

@shoyer
Member

shoyer commented Aug 23, 2019

Within a jit compiled function, JAX's execution speed should be quite competitive on GPUs. It uses the XLA compiler, which was recently enabled by default in TensorFlow.

For data loading and deep learning algorithms, take a look at the examples in the notebooks directory in the JAX repo. The APIs for deep learning in JAX are still undergoing rapid development, so APIs are not quite as stable/usable as pytorch or keras yet, but they are quite capable. See jax.experimental.stax and tensor2tensor.trax for examples.

@fjanoos
Author

fjanoos commented Aug 23, 2019

While it is pretty straightforward to implement a lot of standard xarray operations with a pytorch / JAX backend (since they just fall back on native functions), it will be interesting to think about how to implement rolling, expanding, and exponential window operations in a way that is both efficient and maintains differentiability.

Expanding and exponential window operations would be easy to do by leveraging RNN semantics, but doing rolling using convolutions is going to be very inefficient.
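One possible avenue for the rolling case, as a rough sketch (not a claim about how xarray itself would implement it): PyTorch's Tensor.unfold produces a strided window view, so a rolling reduction stays differentiable without an explicit convolution:

import torch

x = torch.arange(10.0, requires_grad=True)
windows = x.unfold(dimension=0, size=3, step=1)  # shape (8, 3), a view into x
rolling_mean = windows.mean(dim=-1)              # differentiable rolling mean
rolling_mean.sum().backward()
print(x.grad)  # each element's gradient reflects how many windows it falls in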

Do you have any thoughts on this?

@shoyer
Member

shoyer commented Aug 23, 2019 via email

@fjanoos
Author

fjanoos commented Mar 30, 2020

This might be a good time to revive this thread and see if there is wider interest (and bandwidth) in having xarray use CuPy (https://cupy.chainer.org/) as a backend (along with numpy). It appears to be a plug-and-play replacement for numpy, so it might not have all the issues that were brought up regarding pytorch/jax?

Any thoughts ?
cc @mrocklin

@dcherian
Contributor

Just chiming in quickly. I think there's definitely interest in doing this through NEP-18.

It looks like CuPy has implemented __array_function__ (https://docs-cupy.chainer.org/en/stable/reference/interoperability.html), so many things may "just work". There was some work earlier on plugging in pydata/sparse, and there is some ongoing work to plug in pint. With both these efforts, a lot of xarray's code should be "backend-agnostic", but it's not perfect.

Have you tried creating DataArrays with cupy arrays yet? I would just try things and see what works vs what doesn't.
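For anyone giving it a try, a first experiment might look like this (assuming a CUDA-capable GPU with cupy installed; the operations here are just examples):

import cupy as cp
import xarray as xr

da = xr.DataArray(cp.random.rand(4, 3), dims=("x", "y"))
print(type(da.data))            # cupy.ndarray if the wrapping worked
print(type(da.mean("x").data))  # does the reduction stay on the GPU?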

Practically, our approach so far has been to add a number of xfailed tests (test_sparse.py and test_units.py) and slowly start fixing them. So that's one way to proceed if you're up for it.

@jhamman
Member

jhamman commented Mar 30, 2020

@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.

@jakirkham

Yeah Jacob and I played with this a few months back. There were some issues, but my recollection is pretty hazy. If someone gives this another try, it would be interesting to hear how things go.

@fjanoos
Author

fjanoos commented Mar 31, 2020 via email

@jakirkham

Well, here's a blog post on using Dask + CuPy. Maybe start there and build up to using xarray.

@andersy005
Member

@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here.

I've been test-driving xarray objects backed by CuPy arrays, and one issue I keep running into is that operations (such as plotting) that expect numpy arrays fail due to xarray's implicit conversion to NumPy arrays via np.asarray(). CuPy decided not to allow implicit conversion to NumPy arrays (see cupy/cupy#3421).

I am wondering whether there is a plan for dealing with this issue?

Here's a small, reproducible example:

[23]: ds.tmin.data.device
      <CUDA Device 0>
[24]: ds.isel(time=0, lev=0).tmin.plot() # Fails
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-69a72de2b9fd> in <module>
----> 1 ds.isel(time=0, lev=0).tmin.plot()

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in __call__(self, **kwargs)
    444 
    445     def __call__(self, **kwargs):
--> 446         return plot(self._da, **kwargs)
    447 
    448     @functools.wraps(hist)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in plot(darray, row, col, col_wrap, ax, hue, rtol, subplot_kws, **kwargs)
    198     kwargs["ax"] = ax
    199 
--> 200     return plotfunc(darray, **kwargs)
    201 
    202 

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in newplotfunc(darray, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, center, robust, extend, levels, infer_intervals, colors, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
    684 
    685         # Pass the data as a masked ndarray too
--> 686         zval = darray.to_masked_array(copy=False)
    687 
    688         # Replace pd.Intervals if contained in xval or yval.

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in to_masked_array(self, copy)
   2325             Masked where invalid values (nan or inf) occur.
   2326         """
-> 2327         values = self.values  # only compute lazy arrays once
   2328         isnull = pd.isnull(values)
   2329         return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in values(self)
    556     def values(self) -> np.ndarray:
    557         """The array's data as a numpy.ndarray"""
--> 558         return self.variable.values
    559 
    560     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in values(self)
    444     def values(self):
    445         """The variable's data as a numpy.ndarray"""
--> 446         return _as_array_or_item(self._data)
    447 
    448     @values.setter

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in _as_array_or_item(data)
    247     TODO: remove this (replace with np.asarray) once these issues are fixed
    248     """
--> 249     data = np.asarray(data)
    250     if data.ndim == 0:
    251         if data.dtype.kind == "M":

/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

ValueError: object __array__ method not producing an array

@jacobtomlinson
Contributor

@andersy005 I'm about to start working actively on cupy support in xarray. Would be great to get some of your input.

CuPy requests that, instead of calling __array__, you call their .get method for explicit conversion to numpy. So we need to add a little compatibility code for this.
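A hypothetical sketch of the kind of compatibility shim meant here (the helper name is illustrative, not xarray's actual internal function):

import numpy as np

def to_numpy(data):
    # CuPy arrays refuse implicit __array__ coercion, so copy device memory
    # to the host explicitly via .get(); otherwise fall back to np.asarray.
    if type(data).__module__.startswith("cupy"):
        return data.get()
    return np.asarray(data)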

@fjanoos
Author

fjanoos commented Jul 9, 2020

@andersy005 I'm about to start working actively on cupy support in xarray. Would be great to get some of your input.

CuPy requests that, instead of calling __array__, you call their .get method for explicit conversion to numpy. So we need to add a little compatibility code for this.

Do you have a sense of the overhead / effort of making jax vs cupy the GPU backend for xarray? One advantage of jax would be built-in auto-diff functionality, which would enable xarray to be plugged directly into deep learning pipelines. The downside is that it is not as numpy-compatible as cupy. How much of a non-starter would this be?

@jacobtomlinson
Contributor

@fjanoos I'm afraid I don't. In RAPIDS we support cupy as our GPU array implementation. So this request has come from the desire to make xarray compatible with the RAPIDS suite of tools.

We commonly see folks use cupy and then switch straight over to a tool like pytorch via DLPack. https://docs-cupy.chainer.org/en/stable/reference/interoperability.html#dlpack
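That zero-copy hand-off looks roughly like this (a sketch assuming CUDA builds of both cupy and torch; newer versions also expose torch.from_dlpack / cp.from_dlpack):

import cupy as cp
import torch
from torch.utils import dlpack

x_cp = cp.arange(6.0).reshape(2, 3)
x_torch = dlpack.from_dlpack(x_cp.toDlpack())  # zero-copy: shares GPU memory
x_torch += 1
print(x_cp)  # reflects the in-place change made through pytorch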

But I don't really see #4212 as an effort to make cupy the GPU backend for xarray. I see it as adding support for another backend to xarray. The more the merrier!

@Duane321

Duane321 commented Jan 22, 2021

I'd like to cast my vote in favor of getting this functionality in. It would be nice to autodiff through xarray operations.

From reading this and related threads, I'm trying to determine a game plan to make this happen. I'm not familiar with the xarray code, so any guidance would be much appreciated. This is what I'm thinking:

  1. Create a custom subclass of PyTorch's Tensor that provides the methods and attributes required of a duck array. Since this isn't officially supported, it looks like I could run into issues getting this subclass to persist through tensor operations.
  2. Implement the __array_function__ protocol for PyTorch similar to how is demo-ed here.
  3. Pass this custom class into data array constructors and hope the .grad attribute works.

My first attempts at this haven't been successful. Whatever custom class I make and pass to the DataArray constructor gets converted to something xarray can handle with this line:

data = as_compatible_data(data)

Any suggestions would be appreciated. I'm hoping to figure out the shortest path to a working prototype.

@Duane321

No one is working on __array_function__ at the moment. Implementing it has some backwards compat concerns as well, because people may be relying on np.somefunc(some_torch_tensor) being coerced to ndarray. It's not a small project, but implementing a prototype with a few functions in the torch namespace that do not exactly match the NumPy API would be a useful way to start pushing this forward.

@rgommers Do you expect this solution to work with a PyTorch Tensor custom subclass? Or is monkey patching necessary?

@rgommers

rgommers commented Jan 23, 2021

Create a custom subclass of PyTorch's Tensor that provides the methods and attributes required of a duck array. Since this isn't officially supported, it looks like I could run into issues getting this subclass to persist through tensor operations.

If you use PyTorch 1.7.1 or later, then Tensor subclasses are much better preserved through pytorch functions and operations like slicing. So a custom subclass adding the attributes and methods xarray requires for a duck array should be feasible.

data = as_compatible_data(data)

Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Note that I no longer expect that we'll be adding __array_function__ to torch.Tensor, and certainly not any time soon. My current expectation is that the "get the correct namespace from an array/tensor object directly" approach from https://numpy.org/neps/nep-0037-array-module.html#how-to-use-get-array-module and https://data-apis.github.io/array-api/latest/ will turn out to be a much better design long-term.
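The core of that design, in a sketch (illustrative only; the function here is made up): library code asks the array object for its own namespace instead of hard-coding numpy, so the same code runs on any conforming array type.

def demean(x):
    # NEP 37 / array API style: the object tells us which module implements
    # its operations (numpy.array_api, a future torch namespace, ...).
    xp = x.__array_namespace__()
    return x - xp.mean(x)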

@rgommers

Note that the main work in adding __array_function__ is not the dispatch mechanism but mapping to 100% compatible APIs. That job should have gotten a lot easier now compared to 9 months ago. PyTorch now has a completely matching fft module, and a ~70% complete linalg module in master. And functions in the main namespace have gained dtype keywords, integer-to-float promotion, and other NumPy compat changes. So it should be feasible to write your custom subclass.
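A rough sketch of such a subclass (the XArrayTensor name matches what is used later in this thread; Tensor._make_subclass is a private PyTorch helper commonly used for subclassing, so treat this as illustrative rather than a recommended implementation):

import torch

class XArrayTensor(torch.Tensor):
    """torch.Tensor subclass intended to satisfy xarray's duck-array checks."""

    def __new__(cls, data, requires_grad=False):
        return torch.Tensor._make_subclass(cls, torch.as_tensor(data), requires_grad)

    # torch.Tensor already exposes shape, dtype and ndim; the remaining work
    # is implementing __array_ufunc__ / __array_function__ to map onto the
    # NumPy-compatible parts of the torch API, as described above.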

@fjanoos
Author

fjanoos commented Jan 23, 2021

@Duane321
While it would be fantastic to have GPU-enabled, auto-diff-able xarray DataArrays, an interesting development worth looking into is named tensors (https://pytorch.org/docs/stable/named_tensor.html). This appears to be an attempt to bridge the gap from the other side, in that they are making pytorch tensors increasingly DataArray-like. I would not be surprised if within the next few iterations they add indexes to the tensors, closing the gap even further.

@Duane321

While it would be fantastic to have GPU-enabled, auto-diff-able xarray DataArrays, an interesting development worth looking into is named tensors (https://pytorch.org/docs/stable/named_tensor.html). This appears to be an attempt to bridge the gap from the other side, in that they are making pytorch tensors increasingly DataArray-like. I would not be surprised if within the next few iterations they add indexes to the tensors, closing the gap even further.

I really hope so. I explored named tensors at first, but the lack of an index for each dimension was a non-starter. So, I'll keep an eye out.

@Duane321

Duane321 commented Jan 25, 2021

Note that the main work in adding __array_function__ is not the dispatch mechanism but mapping to 100% compatible APIs. That job should have gotten a lot easier now compared to 9 months ago. PyTorch now has a completely matching fft module, and a ~70% complete linalg module in master. And functions in the main namespace have gained dtype keywords, integer-to-float promotion, and other NumPy compat changes. So it should be feasible to write your custom subclass.

Glad to hear there's progress I can lean on. I'll come back with a minimum version that does the API matching for maybe 1-2 methods, just to get feedback on the overall structure. If it works, I can brute-force through a lot of the rest 🤞

Looks like you need to patch that internally just a bit, probably adding pytorch to NON_NUMPY_SUPPORTED_ARRAY_TYPES.

Thank you. I hesitated to change xarray code, but not anymore.

Note that I no longer expect that we'll be adding __array_function__ to torch.Tensor, and certainly not any time soon. My current expectation is that the "get the correct namespace from an array/tensor object directly" approach from https://numpy.org/neps/nep-0037-array-module.html#how-to-use-get-array-module and https://data-apis.github.io/array-api/latest/ will turn out to be a much better design long-term.

Does this mean I shouldn't fill out __array_function__ in my subclass? Or is this just a forward-looking expectation?

@keewis
Collaborator

keewis commented Feb 1, 2021

I can't reproduce that:

In [4]: da.loc["a1"]
Out[4]: 
<xarray.DataArray (b: 2)>
tensor([0.4793, 0.7493], dtype=torch.float32)
Coordinates:
    a        <U2 'a1'
  * b        (b) <U2 'b1' 'b2'

with

numpy: 1.19.5
xarray: 0.16.2
pytorch: 1.7.1.post2
pandas: 1.2.1

Maybe this is an environment issue?

Edit: the missing feature list includes loc (and sel) because it is currently not possible to have a duck array in a dimension coordinate, so this:

xr.DataArray(
    [0, 1, 2],
    coords={"x": XArrayTensor(torch.Tensor([10, 12, 14]))},
    dims="x",
).loc[{"x": XArrayTensor(torch.Tensor([10, 14]))}]

does not work, but

xr.DataArray(
    XArrayTensor(torch.Tensor([0, 1, 2])),
    coords={"x": [10, 12, 14]},
    dims="x",
).loc[{"x": [10, 14]}]

should work just fine.

@Duane321

Duane321 commented Feb 4, 2021

Thanks again @keewis, that was indeed the case. It was due to my older PyTorch version (1.6.0).

@keewis
Collaborator

keewis commented Feb 26, 2021

@Duane321: with xarray>=0.17.0 you should be able to remove the __getattributes__ trick.

@hjalmarlucius

@Duane321 or @keewis, do you have the full code example for making this work? I'm a novice with numpy ufuncs and am trying to get gradients while keeping my xarray coords.

@keewis
Collaborator

keewis commented May 31, 2021

I don't, unfortunately (there's the partial example in #3232 (comment), though).

This is nothing usable right now, but the pytorch maintainers are currently looking into providing support for __array_namespace__ (NEP 47). Once there has been sufficient progress in both numpy and pytorch, we won't have to change much in xarray (i.e. allowing __array_namespace__ instead of __array_ufunc__ / __array_function__ for duck arrays) to make this work without any wrapper code.

You (or anyone interested) might still want to maintain a "pytorch-xarray" convenience library to allow something like arr.torch.grad(dim="x").
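Such a convenience layer could use xarray's accessor machinery; a purely hypothetical sketch (the accessor name and the grad() signature are made up for illustration):

import xarray as xr

@xr.register_dataarray_accessor("torch")
class TorchAccessor:
    def __init__(self, da):
        self._da = da

    def grad(self, dim):
        """Placeholder for an autograd-aware derivative along `dim`."""
        raise NotImplementedError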

@hjalmarlucius

Thanks for the prompt response. Would love to contribute but I have to climb the learning curve first.

@keewis
Collaborator

keewis commented May 31, 2021

Changing the xarray internals is not too much work: we need to get xarray.core.utils.is_duck_array to return True if the object has either __array_namespace__, or __array_ufunc__ and __array_function__ (or all three) defined, and we'd need a short test demonstrating that objects implementing only __array_namespace__ survive unchanged when wrapped by an xarray object (i.e. something like isinstance(xr.DataArray(pytorch_object).mean().data, torch.Tensor)).
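In other words, something along these lines (a sketch of the check as described, not the final implementation):

import numpy as np

def is_duck_array(value) -> bool:
    if isinstance(value, np.ndarray):
        return True
    return (
        hasattr(value, "ndim")
        and hasattr(value, "shape")
        and hasattr(value, "dtype")
        and (
            hasattr(value, "__array_namespace__")
            or (
                hasattr(value, "__array_function__")
                and hasattr(value, "__array_ufunc__")
            )
        )
    )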

We might still be a bit too early with this, though: the PR which adds __array_namespace__ to numpy has not been merged into numpy:main yet.

@zaxtax

zaxtax commented Jan 14, 2022

@keewis @shoyer now that numpy has merged __array_namespace__ support (numpy/numpy#18585) and pytorch is in the process of adding __array_namespace__ support (pytorch/pytorch#58743), is it worth exploring adding support through the __array_namespace__ API?

@andersy005 added the topic-arrays (related to flexible array support) label on Jan 17, 2022
@tomwhite
Contributor

I started having a look at making xarray work with the array API here: tomwhite@c72a1c4. Some basic operations work (preserving the underlying array): tomwhite@929812a. If there's interest, I'd be happy to turn this into a PR with some tests.

@dcherian
Contributor

dcherian commented Jul 13, 2022

I'd be happy to turn this into a PR with some tests.

Absolutely!

@tomwhite
Contributor

Opened #6804

@hsharrison
Contributor

Glad to see progress on this!! 👏

Just curious though, seeing this comment in the PR:

Note: I haven't actually tested this with pytorch (which is the motivating example for #3232).

Are we sure this closes the issue? And, how can we try it out? Even lacking docs, a comment explaining how to set it up would be great, and I can do some testing on my end. I understand that it's an experimental feature.

@tomwhite
Contributor

Hi @hsharrison - thanks for offering to do some testing. Here's a little demo script that you could try, by switching numpy.array_api to pytorch: tomwhite@929812a
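The gist of such a check (this is not the linked script, just an illustrative sketch that assumes the array API support from #6804 and NumPy >= 1.22 for numpy.array_api):

import numpy.array_api as xp
import xarray as xr

arr = xp.asarray([1.0, 2.0, 3.0])
da = xr.DataArray(arr, dims="x")
print(type((da + 1).data))  # should stay an array API array, not np.ndarray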

@hsharrison
Contributor

Nice that it's so simple. I think it can't be tested with pytorch until they complete pytorch/pytorch#58743, right?

Or we should just try passing torch.tensor into xarray directly?

@tomwhite
Contributor

I think it can't be tested with pytorch until they complete pytorch/pytorch#58743, right?

It needs __array_namespace__ to be defined to activate the new code path.

@hsharrison
Contributor

Makes sense, then I'll wait for pytorch/pytorch#58743 to try it.

@jakirkham

While it is true that using PyTorch tensors directly would need the Array API implemented in PyTorch, one could use them indirectly by converting them zero-copy to CuPy arrays, which do have Array API support.
