Named tensors #5048
Comments
Have you started working on this, @juliuskunze?
We are actually working on something that will pretty much realize the plan that @juliuskunze has outlined here, with some additional benefits too (e.g. making it very easy to shard those programs with named axes over multiple accelerators).
Am I correct in assuming that this evolved into named axes, or is there another module I did not find?
That's correct.
@apaszke @froystig That looks awesome! Rad choice not taking the order of named axes into account and broadcasting by name! That's semantically cleaner and probably more future-proof than I expected. (: The thing that I thought would make this impractical is that it's hard to optimize misaligned axes for dot products and similar ops where implicit transposes are needed on device. I guess the performance hit is not so bad, or axis-order optimization could/should be automated in the future anyway? Curious about your thoughts on this. +1 for allowing arrays and operations with named axes outside of
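(The named-axes work referenced in this thread shipped as JAX's experimental `xmap` transform, which was later deprecated. The following is a minimal sketch of the by-name semantics discussed above, assuming a JAX version where `jax.experimental.maps.xmap` is still available.)

```python
# Minimal named-axes sketch using the experimental (since-deprecated) xmap transform.
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def demean(x):
    # Inside xmap, the 'batch' axis is not a positional dimension of x;
    # collectives address it by name instead.
    return x - jax.lax.pmean(x, 'batch')

f = xmap(demean,
         in_axes=['batch', ...],   # axis 0 of the argument is named 'batch'
         out_axes=['batch', ...])  # put 'batch' back at position 0 of the output

x = jnp.arange(6.0).reshape(3, 2)
print(f(x))  # each column now has zero mean over the 'batch' axis
```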
A more powerful implementation is to use first-class dimensions; torchdim uses objects as dimension "variables".
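(For context, a small sketch of the first-class-dimensions approach using torchdim's `functorch.dim` module; this illustrates the idea in PyTorch and is not a JAX API.)

```python
# First-class dimensions with torchdim: dimension objects replace positional axes,
# so reductions and alignment are done by object identity rather than by position.
import torch
from functorch.dim import dims

batch, feature = dims(2)        # two first-class dimension objects
x = torch.randn(4, 3)

xb = x[batch, feature]          # bind positional axes 0 and 1 to the dim objects
y = (xb * xb).sum(feature)      # reduce over 'feature' by object, not by position
print(y.order(batch).shape)     # back to a positional tensor: torch.Size([4])
```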
@apaszke Perhaps it could be made further independent of axis position by utilizing the `named_shape`, for example:

```python
# tensor.named_shape == {'batch': 32, 'time': 100, 'hidden': 200}

# Select the sub-tensor at time=0 and hidden=0, and set it to 1000 (with broadcasting).
tensor[{'time': 0, 'hidden': 0}] = 1000

# JAX would automatically permute dimensions for operations,
# e.g. (batch, time, hidden) -> (time, batch, hidden), so that each slice has
# t.named_shape == {'batch': 32, 'hidden': 200}.
for t in tensor['time']:
    ...
```
@Bit0r I built your suggestion into some code I wrote for operating on a
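(A hypothetical sketch of how the dict-based selection proposed above could be emulated on a plain `jax.numpy` array plus a separate name-to-axis mapping; `select` and the explicit `named_shape` dict are illustrative only, not an existing JAX API, and not the code referenced in the previous comment.)

```python
# Emulate selection by axis name on a plain jax.numpy array; illustrative helper only.
import jax.numpy as jnp

def select(x, named_shape, index):
    """Index axes by name, e.g. select(x, {'batch': 32, ...}, {'time': 0})."""
    axes = list(named_shape)              # axis names in positional order
    idx = [slice(None)] * x.ndim
    for name, i in index.items():
        idx[axes.index(name)] = i
    return x[tuple(idx)]

x = jnp.zeros((32, 100, 200))
named_shape = {'batch': 32, 'time': 100, 'hidden': 200}

row = select(x, named_shape, {'time': 0, 'hidden': 0})
print(row.shape)  # (32,): only the 'batch' axis remains
```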
PyTorch has experimental support for named tensors achieving some compelling design goals while keeping existing code compatible. For example, binop broadcasting is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX/... semantics, but checks that aligned dimension names match.
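(To make that concrete, a small example with PyTorch's experimental named tensor API: alignment is still positional from the right, but the names of aligned dimensions must match.)

```python
# PyTorch named tensors: broadcasting stays positional, but aligned names are checked.
import torch

x = torch.randn(2, 3, names=('N', 'C'))
y = torch.randn(3, names=('C',))
print((x + y).names)        # ('N', 'C'): the aligned dims are both named 'C'

w = torch.randn(3, names=('W',))
try:
    x + w                   # aligned dims are named 'C' and 'W'
except RuntimeError as err:
    print('name mismatch:', err)
```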
It would be great to have named tensors that work both in op-by-op execution and under function transformations in JAX.
@shoyer In #1565 you mentioned that this could be done by wrapping JAX based on #611. According to my current understanding, this means:

1. Implement an `eval_names` transform.
2. Implement `NamedDeviceArray`, a subtype of `DeviceArray` that adds a `names` property.
3. Make ops applicable to `NamedDeviceArray`s. For that, mirror `jax.numpy`, wrapping each op with the `named` transform.
4. Implement NumPy's public API on `NamedDeviceArray` using #611 (Implement overrides of NumPy's public API on JAX arrays; +1 for merging). Alternatively, one could rewrite `jax.numpy` using `numpy_dispatch.get_array_module` from #4076 (Add experimental `__array_module__` method), which appears cumbersome.
5. Make sure `jit`ted functions propagate names when applied to `NamedDeviceArray`s.

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) #611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in accordance with #1565. (:
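(Not the actual proposal, but a minimal sketch of the kind of wrapper the plan above describes; `NamedArray` and `named` are hypothetical names and do not exist in JAX.)

```python
# Hypothetical prototype: pair a jax.numpy array with per-axis names and wrap
# elementwise ops so they check and propagate those names.
import jax.numpy as jnp

class NamedArray:
    """A plain array paired with one name per axis."""
    def __init__(self, data, names):
        assert data.ndim == len(names)
        self.data = data
        self.names = tuple(names)

    def __repr__(self):
        return f"NamedArray(shape={self.data.shape}, names={self.names})"

def named(op):
    """Wrap an elementwise jax.numpy op so it checks and propagates axis names."""
    def wrapped(*args):
        named_args = [a for a in args if isinstance(a, NamedArray)]
        names = named_args[0].names
        if any(a.names != names for a in named_args):
            raise ValueError("axis names of the operands do not match")
        raw = [a.data if isinstance(a, NamedArray) else a for a in args]
        return NamedArray(op(*raw), names)
    return wrapped

add = named(jnp.add)
x = NamedArray(jnp.ones((2, 3)), ('batch', 'feature'))
y = NamedArray(jnp.zeros((2, 3)), ('batch', 'feature'))
print(add(x, y))  # NamedArray(shape=(2, 3), names=('batch', 'feature'))
```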