
How to Verify In-Place Updates by the Compiler in JAX #22930

Answered by jakevdp
bheijden asked this question in Q&A

There's no direct way to check whether XLA performs a given update in place, but you can get an idea of what the compiler is doing with your code by printing the optimized HLO (the program after XLA's optimization passes). For example:

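```python
import jax

def f(x):
  x = x.at[0].set(1)
  return x.sum()

x = jax.numpy.arange(10)

# .lower(x) traces f for x's shape and dtype, .compile() runs XLA's
# optimization passes, and .as_text() returns the optimized HLO module.
print(jax.jit(f).lower(x).compile().as_text())
```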
```
HloModule jit_f, entry_computation_layout={(s32[10]{0})->s32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

%region_1.9 (Arg_0.10: s32[], Arg_1.11: s32[]) -> s32[] {
  %Arg_0.10 = s32[] parameter(0)
  %Arg_1.11 = s32[] parameter(1)
  ROOT %add.12 = s32[] add(s32[] %Arg_0.10, s32[] %Arg_1.11), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0…
```
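A complementary check, sketched here under the assumption that you are free to donate the input buffer: pass `donate_argnums` to `jax.jit` and search the compiled module text for `input_output_alias`, the HLO module attribute where XLA records that an output may reuse an input's memory. The helper `has_io_alias` and the function `update_first` are hypothetical names for illustration. A negative result is not proof that a copy is made, and whether donation is honoured depends on the backend and JAX version.

```python
import jax
import jax.numpy as jnp


def update_first(x):
    # Output has the same shape and dtype as the input, so a donated
    # input buffer is eligible to back the output.
    return x.at[0].set(1)


def has_io_alias(fn, *args, donate_argnums=()):
    """Hypothetical helper: compile `fn` for `args` and report whether the
    optimized HLO text records any input/output buffer aliasing."""
    compiled = jax.jit(fn, donate_argnums=donate_argnums).lower(*args).compile()
    text = compiled.as_text()
    return "input_output_alias" in text, text


x = jnp.arange(10)
# Lowering and compiling do not execute fn, so x is still usable afterwards.
donated, donated_hlo = has_io_alias(update_first, x, donate_argnums=0)
plain, plain_hlo = has_io_alias(update_first, x)
print("with donation:   ", donated)
print("without donation:", plain)
```

Comparing the two results isolates the effect of donation: without `donate_argnums` the caller still owns `x`, so the compiled program cannot hand its buffer back as the output.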

Answer selected by bheijden