pure_callback is broken with multiple vmap #23624

Open · Joshuaalbert opened this issue Sep 13, 2024 · 7 comments
Labels: bug (Something isn't working)

Comments

@Joshuaalbert (Contributor)
Description

When vectorized=True, the expectation is that the callback passed to pure_callback should vectorize over common leading batch dims. That is, the batch dims of all mapped arrays should be identical, with shape broadcasting performed on the JAX side. An array that has not been mapped should not receive a batch dim. If this contract is violated, it is impossible for the callback to construct the proper output shape.

from functools import partial
import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def add_vmapped(x, y, z):
    return x + y + z


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_no_vec(x, y, z):
    def add(x, y, z):
        # With vectorized=False the callback is applied elementwise,
        # so every argument arrives as a scalar.
        assert x.shape == ()
        assert y.shape == ()
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=False)


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        # Expected: mapped args arrive with the common (4, 5) batch shape;
        # the unmapped z stays a scalar.
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert add_vmapped(x, y, z).shape == (4, 5)
    assert cb_no_vec(x, y, z).shape == (4, 5)
    assert cb_vec(x, y, z).shape == (4, 5)

System info (python version, jaxlib version, accelerator, etc.)

jax==0.4.31
jaxlib==0.4.31
Joshuaalbert added the bug label on Sep 13, 2024
@Joshuaalbert (Contributor · Author)

Similar to #17187; not sure I follow the logic of this comment.

@dfm (Collaborator) commented Sep 13, 2024

> That is, the batch dims of all mapped arrays should be identical, with shape broadcasting performed on the JAX side.

This actually isn't the behavior of vectorized! I know that the way it's presented in the docs is confusing, and I'm actually pushing to deprecate the vectorized behavior in favor of a more expressive API. I think that what you want is something like a "broadcasting vmap", which can be built using custom_vmap. Something like the following should do the trick:

def broadcasting_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    # Tile every unbatched argument up to the full axis size so the
    # wrapped function always sees consistent leading batch dims.
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (axis_size,)), args,
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f
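
For instance, under a single vmap this helper tiles any unmapped argument to the full axis size before the function body runs. A minimal sketch with made-up sizes (assuming broadcasting_vmap as defined above):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(0, None))
@broadcasting_vmap
def f(x, y):
    # Both arguments arrive with the full batch dim: x because it is
    # mapped, y because the custom_vmap rule tiled it to axis_size.
    assert x.shape == (3,)
    assert y.shape == (3,)
    return x + y

assert f(jnp.arange(3.0), jnp.float32(10.0)).shape == (3,)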

dfm added a commit to dfm/jax that referenced this issue Sep 13, 2024
It has come up a few times (most recently in
jax-ml#23624) that the "vectorized"
behavior of `pure_callback` and `ffi_call` is confusing. I'm working on
improving that, but in the meantime, it seems like it would be useful to
provide a `broadcasting_vmap` similar to the `sequential_vmap` helper
that we currently have for vmapping with a `lax.map`.
@Joshuaalbert (Contributor · Author)

It might be just the trick. However, can I suggest you make sure it passes this?

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert cb_vec(x, y, z).shape == (4, 5)

@dfm (Collaborator) commented Sep 13, 2024

Are you sure you want assert z.shape == ()? My suggestion was that you write:

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@broadcasting_vmap  # <--------------------------- HERE
def cb_vec(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == (4, 5)
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z)
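
In other words, each vmap level tiles every argument that isn't mapped at that level to that level's axis size, which is why z picks up both batch dims by the time the callback runs. With the inputs from the repro, this version should then pass (a quick check restating the shapes above):

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

assert cb_vec(x, y, z).shape == (4, 5)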

@Joshuaalbert (Contributor · Author) commented Sep 13, 2024 via email

@dfm (Collaborator) commented Sep 13, 2024

I don't think there's any good way to get that behavior. The inner vmap doesn't "know" about the outer one, so I expect you'll be hard-pressed to come up with consistent logic that leaves z a scalar. One thing you probably could get is shapes (4, 1), (1, 5), and (1, 1), if that's better for your use case:

A possible implementation:
def joshuaalbert_vmap(f):
  f = jax.custom_batching.custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    batched_args = jax.tree.map(
        lambda x, b: x if b else jax.lax.broadcast(x, (1,)), args,  # <- 1 instead of axis_size
        tuple(in_batched))
    out = f(*batched_args)
    out_batched = jax.tree.map(lambda _: True, out)
    return out, out_batched

  return f

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@joshuaalbert_vmap
def cb_broadcasting(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 1)
        assert y.shape == (1, 5)
        assert z.shape == (1, 1)
        return x + y + z
    out_shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape)  # <-- note here
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=out_shape, dtype=x.dtype), x, y, z)

The issue is that there needs to be some logic for which arguments to broadcast in each vmap, and that logic can't depend on whether or not an argument is going to be mapped in the future. "vectorized" handles this by never mapping anything that isn't mapped, and I think it's unlikely that we could come up with sensible logic to get exactly what you're asking for here. All that to say, I do think you might be able to come up with something that works for your use case using custom_vmap, and maybe that will help clarify your feature request.
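
To make the "never mapping anything that isn't mapped" point concrete, here is my reading of the vectorized=True semantics under a single vmap, as a small sketch (example sizes are made up):

@partial(jax.vmap, in_axes=(0, None))
def g(x, y):
    def impl(x, y):
        # The mapped argument arrives with its batch dim; the unmapped
        # argument is passed through with no leading dimension.
        assert x.shape == (4,)
        assert y.shape == ()
        return x + y
    return jax.pure_callback(impl, jax.ShapeDtypeStruct(x.shape, x.dtype), x, y, vectorized=True)

assert g(jnp.arange(4.0), jnp.float32(1.0)).shape == (4,)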

@Joshuaalbert (Contributor · Author)

I understand the constraint. Hmm, perhaps there is another middle ground. In principle, if an argument should never be broadcast, it can be curried away. The remaining args can then receive broadcasting, converting the function into a ufunc-style function. In the interest of a clear API, you might merge both broadcast choices above and rename the helper to convert_to_ufunc, with a tile boolean that determines whether unmapped arguments are tiled to the full axis size or just given a broadcastable size-1 dim.

def convert_to_ufunc(f, tile: bool = True):
    f = jax.custom_batching.custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Unmapped args get a leading dim: the full axis size if tile=True,
        # otherwise a broadcastable size-1 dim.
        batched_args = jax.tree.map(
            lambda x, b: x if b else jax.lax.broadcast(x, ((axis_size if tile else 1),)), args,
            tuple(in_batched))
        out = f(*batched_args)
        out_batched = jax.tree.map(lambda _: True, out)
        return out, out_batched

    return f

def cb(x, y, z):
    def add(x, y, z):
        assert x.shape == (4, 5)  # (4, 1) if tile=False
        assert y.shape == (4, 5)  # (1, 5) if tile=False
        assert z.shape == ()      # z is curried away, so it never gets a batch dim
        return x + y + z

    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype),
                             x, y, z, vectorized=True)

# Curry z first
assert jax.vmap(jax.vmap(convert_to_ufunc(partial(cb, z=z)), in_axes=(None, 0)), in_axes=(0, None))(x, y).shape == (4, 5)

With this setup the original intent of this issue is resolved: we can now trust that applying vmap multiple times yields consistent shapes inside the callback, which makes the behavior much easier to reason about.
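
For completeness, a self-contained sketch of the final setup (it assumes convert_to_ufunc and cb exactly as defined in this comment):

from functools import partial

import jax
import jax.numpy as jnp

# convert_to_ufunc and cb as defined above

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

# Curry z so it never receives a batch dim, then vmap over x and y.
f = convert_to_ufunc(partial(cb, z=z))  # tile=True by default
out = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)
assert out.shape == (4, 5)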
