# pure_callback is broken with multiple vmap #23624
Similar to #17187, not sure I follow the logic of this comment.
This actually isn't the behavior of `vectorized=True`, but you can get it with something like:

```python
import jax

def broadcasting_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    # Broadcast every unbatched argument up to the full axis size, so the
    # wrapped function always sees a leading batch dimension on all inputs.
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (axis_size,)), args,
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f
```
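For reference, a minimal usage sketch under a single `vmap` (the names and sizes here are illustrative, not from the thread): because the rule broadcasts unbatched arguments to the full `axis_size`, the callback sees the mapped axis on every input.

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.vmap, in_axes=(0, None))
@broadcasting_vmap
def f(x, y):
  def add(x, y):
    # y was not mapped, but the rule broadcast it to the mapped axis size
    assert x.shape == (4,)
    assert y.shape == (4,)
    return x + y
  return jax.pure_callback(add, jax.ShapeDtypeStruct(x.shape, x.dtype), x, y)

print(f(jnp.arange(4.0), jnp.float32(2.0)))  # [2. 3. 4. 5.]
```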
It has come up a few times (most recently in jax-ml#23624) that the "vectorized" behavior of `pure_callback` and `ffi_call` is confusing. I'm working on improving that, but in the meantime, it seems like it would be useful to provide a `broadcasting_vmap` similar to the `sequential_vmap` helper that we currently have for vmapping with a `lax.map`.
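(For context, `jax.custom_batching.sequential_vmap` evaluates the function once per batch element via `lax.map`. A rough sketch of the idea, assuming at least one batched argument — not the actual implementation:)

```python
import jax

def sequential_vmap_sketch(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    # Loop over the mapped axis with lax.map; unbatched args are closed
    # over unchanged, so f never sees an extra batch dimension on them.
    def step(slices):
      it = iter(slices)
      full = tuple(next(it) if b else a for a, b in zip(args, in_batched))
      return f(*full)

    batched = tuple(a for a, b in zip(args, in_batched) if b)
    out = jax.lax.map(step, batched)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f
```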
It might be just the trick. However, can I suggest you make sure it passes this test?
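(The test itself was collapsed in the thread. Judging from the reply that follows, it was presumably this double-`vmap` setup, with the unmapped `z` expected to stay a scalar — a reconstruction, not the original:)

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap
def cb_vec(x, y, z):
  def add(x, y, z):
    assert x.shape == (4, 5)
    assert y.shape == (4, 5)
    assert z.shape == ()  # unmapped arg should not be broadcast
    return x + y + z
  return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)
```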
Are you sure you want `assert z.shape == ()`? My suggestion was that you write:

```python
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  # <--------------------------- HERE
def cb_vec(x, y, z):
  def add(x, y, z):
    assert x.shape == (4, 5)
    assert y.shape == (4, 5)
    assert z.shape == (4, 5)
    return x + y + z
  return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)
```
The problem is that z should be a scalar inside the func, not broadcasted. Note this is not ufunc behaviour but _is_ what I am looking for: mapped args are broadcasted, unmapped args are not.
I don't think there's any good way to get that behavior. The inner vmap doesn't "know" about the outer one, so I expect you'll be hard pressed to come up with consistent logic to end up with `z` a scalar. One thing you probably could get would be shapes that broadcast against each other.

A possible implementation:

```python
import jax
import jax.numpy as jnp
from functools import partial

def joshuaalbert_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (1,)), args,  # <- 1 instead of axis_size
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@joshuaalbert_vmap
def cb_broadcasting(x, y, z):
  def add(x, y, z):
    assert x.shape == (4, 1)
    assert y.shape == (1, 5)
    assert z.shape == (1, 1)
    return x + y + z
  out_shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape)  # <-- note here
  return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=out_shape, dtype=x.dtype), x, y, z)
```

The issue is that there needs to be some logic for which arguments to broadcast in each vmap, and that can't depend on whether or not an argument is going to be mapped in the future. "vectorized" handles this by never broadcasting anything that isn't mapped, and I think that it's unlikely that we could come up with sensible logic to get exactly what you're asking for here. All that to say, I do think that you might be able to come up with something that works for your use case using `jax.custom_batching.custom_vmap`.
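A quick sanity check of that sketch, under the same assumptions as the asserts above (mapped sizes 4 and 5, scalar `z`):

```python
x = jnp.arange(4.0)
y = jnp.arange(5.0)
z = jnp.float32(1.0)

out = cb_broadcasting(x, y, z)
print(out.shape)  # (4, 5): x[:, None] + y[None, :] + z
```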
I understand the constraint. Hmm, perhaps there is another middle ground. In principle, if an argument should never be broadcasted, then it can be curried. The remaining args can then receive broadcasting, converting the function into a ufunc-style func. I think, in an effort to make the API clear, you might merge both of the above broadcast choices and rename it accordingly.
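For concreteness, a minimal sketch of the currying idea (hypothetical; `make_cb` and the concrete scalar `z` are illustrative, not from the thread): the never-broadcast argument is closed over, so only the ufunc-style arguments go through the vmaps.

```python
import jax
import jax.numpy as jnp
from functools import partial

def make_cb(z):
  # z is curried: the callback closes over it, so it is never vmapped
  # and stays a scalar on the host side.
  @partial(jax.vmap, in_axes=(0, None))
  @partial(jax.vmap, in_axes=(None, 0))
  @broadcasting_vmap
  def cb(x, y):
    def add(x, y):
      assert x.shape == (4, 5)
      assert y.shape == (4, 5)
      return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y)
  return cb

out = make_cb(1.0)(jnp.arange(4.0), jnp.arange(5.0))  # shape (4, 5)
```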
With this setup, the original intent of this issue is resolved: we can now trust that applying vmap multiple times gives consistent shapes inside the callback, which allows easier reasoning.
## Description
When `vectorized=True`, the expectation is that the `callback` of `pure_callback` should vectorise over common leading batch dims. That is, all batch dims of any mapped array should be identical, with shape broadcasting performed on the JAX side. If an array has not been mapped, then it should not receive a batch dim. If this is violated, then it is impossible for the `callback` to construct the proper output shape.

## System info (python version, jaxlib version, accelerator, etc.)
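Concretely, the expectation stated above amounts to the following shape contract (sizes assumed for illustration):

```python
# Given cb(x, y, z) with x mapped over an axis of size 4, y mapped over an
# axis of size 5, and z unmapped, the callback should be invoked with:
#   x.shape == (4, 5)   # mapped: broadcast to the common batch dims
#   y.shape == (4, 5)   # mapped: broadcast to the common batch dims
#   z.shape == ()       # unmapped: no batch dim added
```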