jax.jit(foo, device=dev) vs jax.device_put and interpreted vs JIT-ed #22963
Hi, I am thinking about tentatively implementing a PJRT plugin, and I have some questions about the user-facing interfaces. At the moment, I imagine my PJRT plugin taking an MLIR module as input via this function:

```cpp
// Variant of `Compile` that accepts an MLIR module.
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
    mlir::ModuleOp module, CompileOptions options) = 0;
```

It would compile the module to some executable, then later receive the concrete inputs and execute the function. I really like

I think the following two are equivalent, but I am not sure. Can you confirm?

```python
x = jnp.array(0)
jax.jit(foo, device=dev)(x)
```

and

```python
x = jnp.array(0)
x = jax.device_put(x, dev)
jax.jit(foo)(x)
```

I can see this working by

Also, I have a question about backends when things are not executed via

```python
# no jax.jit context
x = jnp.array(0)
x = jax.device_put(x, dev)
x + x
```

Does this mean that the python line
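For context on the `Compile(mlir::ModuleOp, ...)` overload mentioned above, here is a minimal sketch of how to inspect the MLIR module JAX produces for a jitted function. It assumes a recent JAX version in which `Lowered.as_text()` emits StableHLO; `foo` is a hypothetical example function. Roughly speaking, when a PJRT plugin is the active backend, `lowered.compile()` is the step that hands this program to the plugin, and calling the compiled object supplies the concrete inputs.

```python
import jax
import jax.numpy as jnp


def foo(x):
    # Hypothetical example function; any jittable function works here.
    return jnp.sin(x) * 2.0


x = jnp.array(0.0)

# Trace and lower `foo` for this input's shape/dtype without executing it.
lowered = jax.jit(foo).lower(x)

# Textual MLIR module (StableHLO in recent JAX versions): roughly the kind of
# program a PJRT backend is asked to compile.
mlir_text = lowered.as_text()
print(mlir_text)

# Compile with the active backend, then execute with concrete inputs.
compiled = lowered.compile()
print(compiled(x))
```

Splitting `lower()` from `compile()` mirrors the compile-then-execute flow described in the question, which makes it a convenient way to capture sample modules when testing a plugin.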
Yes, these should be equivalent, but only in the single-device case. If `x` is sharded, they won't be equivalent. Also, `jax.jit(f, device=)` is deprecated, so you should use the latter form.

Yes, that's correct. JAX uses "computation follows data" semantics. `+` is jitted internally, so it will be executed on the default device on the default backend that's present.
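A minimal sketch for checking this placement behavior yourself. It assumes a recent JAX version in which `jax.Array.devices()` returns the set of devices an array lives on; `foo` is a hypothetical function. The sketch also runs on a CPU-only machine, where every device choice collapses to the single CPU device. On a multi-device machine, `jax.devices()[-1]` picks a non-default device, so the "computation follows data" effect becomes visible.

```python
import jax
import jax.numpy as jnp


def foo(x):
    # Hypothetical example function; stands in for any jittable function.
    return x * 2 + 1


# Last visible device; on a single-device machine this is also the default.
dev = jax.devices()[-1]

# Committed input: device_put with an explicit device pins x to `dev`.
x = jax.device_put(jnp.array(0), dev)

# jit without any device argument: the computation follows x to `dev`.
y = jax.jit(foo)(x)
print(y.devices())

# Eager op outside any jit: also runs where the committed input lives.
z = x + x
print(z.devices())

# Uncommitted input: never explicitly placed, so it uses the default device.
u = jnp.array(0)
w = u + u
print(w.devices())
```

The uncommitted case at the end is where the default device applies. Because `u` was never passed to `jax.device_put` with an explicit device, `u + u` runs on the default device rather than following a committed placement.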