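You can probably implement what you want using `jnp.unpackbits`:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# unpackbits expands each uint8 into its 8 bits, most significant bit first;
# keeping the last 5 columns gives a 5-bit representation of each value.
bits = jnp.unpackbits(x.astype(jnp.uint8)).reshape(len(x), 8)[:, 3:]
print(bits)
```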
Note that this will not work with `grad`, because the inputs and outputs are integers. I'm not sure it's really meaningful to talk about the gradient of an operation like this: gradients involve taking limits of infinitesimal changes, which rules out integers. If there's a particular application you have in mind, it might help to say more about it.
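To make the point about gradients concrete, here is a small sketch that extracts bits from float inputs using only `jnp.floor` and `jnp.mod`, so that `jax.grad` will accept it. The helper `float_bits` is a hypothetical name, not a JAX API, and it assumes non-negative inputs. The computation is traceable by `grad`, but JAX reports a gradient of zero, because `floor` is piecewise constant; integer inputs are rejected by `jax.grad` outright.

```python
import jax
import jax.numpy as jnp


def float_bits(x, num_bits):
    """Hypothetical helper: the low `num_bits` bits of floor(x), MSB first.

    Uses only float ops so that jax.grad accepts the input. Assumes x >= 0.
    """
    # Place values [2**(num_bits-1), ..., 2, 1], most significant first.
    place_values = 2.0 ** jnp.arange(num_bits - 1, -1, -1)
    return jnp.mod(jnp.floor(x[:, None] / place_values), 2.0)


x = jnp.array([1.0, 2.0, 3.0])
print(float_bits(x, 5))

# floor is piecewise constant, so JAX assigns it a zero derivative and the
# gradient of any function of these bits with respect to x is all zeros.
g = jax.grad(lambda v: float_bits(v, 5).sum())(x)
print(g)

# Passing integer inputs directly to jax.grad raises a TypeError instead.
try:
    jax.grad(lambda v: float_bits(v, 5).sum())(jnp.array([1, 2, 3]))
except TypeError as err:
    print("TypeError:", err)
```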
@jakevdp Is there an implementation for numbers wider than 8 bits, or a way to indicate an overflow?
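One possible approach for widths beyond 8 bits is to skip `unpackbits` and read bits directly with shifts and masks. This is a sketch, not an existing JAX function: `int_to_bits` and its overflow flag are hypothetical names, and it assumes int32 inputs (JAX's default integer type when 64-bit mode is off) and widths of at most 31 bits.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def int_to_bits(x, num_bits):
    """Hypothetical helper: fixed-width binary expansion with an overflow flag.

    Returns (bits, overflow): bits has shape (len(x), num_bits), MSB first,
    and overflow[i] is True when x[i] is negative or needs more than num_bits bits.
    """
    # num_bits is static, so this check runs once at trace time. Limiting it
    # to 31 keeps every shift amount below the width of int32.
    if not 1 <= num_bits <= 31:
        raise ValueError("num_bits must be between 1 and 31 for int32 inputs")
    x = jnp.asarray(x, dtype=jnp.int32)
    shifts = jnp.arange(num_bits - 1, -1, -1, dtype=jnp.int32)
    # Shift each value right by k and keep the lowest bit to read bit k.
    bits = (x[:, None] >> shifts) & 1
    # Anything left after shifting out num_bits bits means the value did not
    # fit; negatives land here too, since >> on int32 is an arithmetic shift.
    overflow = (x >> num_bits) != 0
    return bits.astype(jnp.uint8), overflow


bits, overflow = int_to_bits(jnp.array([1, 2, 3]), 5)
print(bits)
print(overflow)

# 300 needs 9 bits: it is flagged at width 8 but fits at width 12.
narrow_bits, narrow_overflow = int_to_bits(jnp.array([300, 5]), 8)
wide_bits, wide_overflow = int_to_bits(jnp.array([300, 5]), 12)
print(narrow_overflow, wide_overflow)
```

Because `num_bits` determines the output shape, it has to be a static argument under `jit`; each distinct width triggers a separate compilation.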
I am struggling to think of a way to convert an array (or matrix) of integers to a matrix of binary values of a fixed width that is both jit-able and grad-able. E.g. `convert([1, 2, 3], 5)` -> `[[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1]]`. I have looked online and on the discussion forum and didn't see anyone who had encountered this before. My current implementation is:
This computes what I want (and works with `jit`), but it doesn't seem to work with `grad`. There is a `binary_repr` function in NumPy but not in JAX. There is also an `unpackbits` function, but that isn't quite the same as this problem. I am curious whether anyone has encountered this before or has any suggestions. Thanks.
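For comparison, a host-side reference for the desired `convert` can be written with NumPy's `np.binary_repr`. It is handy for checking a JAX version against, but the per-scalar string round-trip is what keeps it out of `jit`. `convert_numpy` is a hypothetical name, and the sketch assumes non-negative integers that fit in `num_bits` bits.

```python
import numpy as np


def convert_numpy(values, num_bits):
    """Hypothetical NumPy reference for the fixed-width conversion.

    np.binary_repr handles one Python int at a time and returns a string,
    so this loop runs eagerly on the host and cannot be traced by jax.jit.
    """
    return np.array(
        [[int(c) for c in np.binary_repr(int(v), width=num_bits)] for v in values],
        dtype=np.uint8,
    )


print(convert_numpy([1, 2, 3], 5))
```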