First reported at jax-ml/jax#22865
This appears to be an issue with one of the fusion passes for the vmapped program.
Here's a more stripped-down repro, with the traceback from a GPU runtime (jax v0.4.31):
```python
import jax
import jax.numpy as jnp

@jax.vmap
def fn(x):
  R1 = jnp.array([[x[0], 0, 0], [0, x[0], 0], [0, 0, x[0]]])
  R2 = jnp.array([[x[0], 0, 0], [0, x[1], 0], [0, 0, x[2]]])
  H = jnp.eye(4)
  H = H.at[:3, :3].set(R2.T)
  pos = H @ jnp.concatenate([x, jnp.array([1.0])])
  return pos, R1

x_v = jnp.zeros((5, 3))
jax.jit(fn).lower(x_v).compile()
```
```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-1-11d7cd79e4c8> in <cell line: 18>()
     16
     17 x_v = jnp.zeros((5, 3))
---> 18 jax.jit(fn).lower(x_v).compile()

/usr/local/lib/python3.10/dist-packages/jax/_src/stages.py in compile(self, compiler_options)
    673     kw: dict[str, Any] = {"compiler_options": compiler_options}
    674     return Compiled(
--> 675         self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
    676         self.args_info,
    677         self.out_tree,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in compile(self, compiler_options)
   2293   def compile(self, compiler_options=None) -> MeshExecutable:
   2294     if self._executable is None or compiler_options is not None:
-> 2295       executable = UnloadedMeshExecutable.from_hlo(
   2296           self._name, self._hlo, **self.compile_args,
   2297           compiler_options=compiler_options)

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in from_hlo(***failed resolving arguments***)
   2805         break
   2806
-> 2807     xla_executable = _cached_compilation(
   2808         hlo, name, mesh, spmd_lowering,
   2809         tuple_args, auto_spmd_lowering, allow_prop_to_inputs,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values, pgle_profiler)
   2619         "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
   2620         fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2621       xla_executable = compiler.compile_or_get_cached(
   2622           backend, computation, dev, compile_options, host_callbacks,
   2623           pgle_profiler)

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks, pgle_profiler)
    397   else:
    398     log_persistent_cache_miss(module_name, cache_key)
--> 399   return _compile_and_write_cache(
    400       backend,
    401       computation,

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in _compile_and_write_cache(backend, computation, compile_options, host_callbacks, module_name, cache_key)
    625 ) -> xc.LoadedExecutable:
    626   start_time = time.monotonic()
--> 627   executable = backend_compile(
    628       backend, computation, compile_options, host_callbacks
    629   )

/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    334   def wrapper(*args, **kwargs):
    335     with TraceAnnotation(name, **decorator_kwargs):
--> 336       return func(*args, **kwargs)
    337   return wrapper
    338   return wrapper

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in backend_compile(backend, module, options, host_callbacks)
    265   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    266   # to take in `host_callbacks`
--> 267   return backend.compile(built_c, compile_options=options)
    268
    269 def compile_or_get_cached(

XlaRuntimeError: INVALID_ARGUMENT: Binary op with incompatible shapes: f32[4,5,4] and f32[5,4,4].
```
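To help narrow down which pass introduces the mismatched f32[4,5,4] / f32[5,4,4] operands, here is a hedged debugging sketch using standard XLA dump flags (the dump directory is arbitrary, and which pass is at fault remains an assumption until confirmed from the dump):

```python
import os

# Must be set before JAX initializes its backend, i.e. before the first
# computation runs in the process.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "   # arbitrary dump directory
    "--xla_dump_hlo_pass_re=.*"      # dump the HLO after every pass
)

# Re-running the repro above then leaves per-pass HLO snapshots in
# /tmp/xla_dump; the first snapshot containing the f32[4,5,4] vs f32[5,4,4]
# mismatch should point at the pass that miscompiles the vmapped program.
```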
The HLO that is sent to XLA looks like this:
```python
print(jax.jit(fn).lower(x_v).as_text())
```
```mlir
module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<5x3xf32> {mhlo.layout_mode = "default"}) -> (tensor<5x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<5x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
    %0 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %1 = stablehlo.reshape %0 : (tensor<5x1xf32>) -> tensor<5xf32>
    %2 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<5x1xf32>) -> tensor<5xf32>
    %4 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %5 = stablehlo.reshape %4 : (tensor<5x1xf32>) -> tensor<5xf32>
    %6 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %9 = stablehlo.broadcast_in_dim %7, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %10 = stablehlo.broadcast_in_dim %8, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %11 = stablehlo.concatenate %6, %9, %10, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %12 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %13 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %15 = stablehlo.broadcast_in_dim %12, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %16 = stablehlo.broadcast_in_dim %14, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %17 = stablehlo.concatenate %15, %13, %16, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %18 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %19 = stablehlo.broadcast_in_dim %cst_5, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %20 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %21 = stablehlo.broadcast_in_dim %18, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %22 = stablehlo.broadcast_in_dim %19, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %23 = stablehlo.concatenate %21, %22, %20, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %24 = stablehlo.broadcast_in_dim %11, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %25 = stablehlo.broadcast_in_dim %17, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %26 = stablehlo.broadcast_in_dim %23, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %27 = stablehlo.concatenate %24, %25, %26, dim = 1 : (tensor<5x1x3xf32>, tensor<5x1x3xf32>, tensor<5x1x3xf32>) -> tensor<5x3x3xf32>
    %28 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %29 = stablehlo.reshape %28 : (tensor<5x1xf32>) -> tensor<5xf32>
    %30 = stablehlo.slice %arg0 [0:5, 1:2] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %31 = stablehlo.reshape %30 : (tensor<5x1xf32>) -> tensor<5xf32>
    %32 = stablehlo.slice %arg0 [0:5, 2:3] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %33 = stablehlo.reshape %32 : (tensor<5x1xf32>) -> tensor<5xf32>
    %34 = stablehlo.broadcast_in_dim %29, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %35 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %36 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %37 = stablehlo.broadcast_in_dim %35, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %38 = stablehlo.broadcast_in_dim %36, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %39 = stablehlo.concatenate %34, %37, %38, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %40 = stablehlo.broadcast_in_dim %cst_8, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %41 = stablehlo.broadcast_in_dim %31, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %42 = stablehlo.broadcast_in_dim %cst_9, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %43 = stablehlo.broadcast_in_dim %40, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %44 = stablehlo.broadcast_in_dim %42, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %45 = stablehlo.concatenate %43, %41, %44, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %46 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %47 = stablehlo.broadcast_in_dim %cst_11, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %48 = stablehlo.broadcast_in_dim %33, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %49 = stablehlo.broadcast_in_dim %46, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %50 = stablehlo.broadcast_in_dim %47, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %51 = stablehlo.concatenate %49, %50, %48, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %52 = stablehlo.broadcast_in_dim %39, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %53 = stablehlo.broadcast_in_dim %45, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %54 = stablehlo.broadcast_in_dim %51, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %55 = stablehlo.concatenate %52, %53, %54, dim = 1 : (tensor<5x1x3xf32>, tensor<5x1x3xf32>, tensor<5x1x3xf32>) -> tensor<5x3x3xf32>
    %56 = stablehlo.iota dim = 0 : tensor<4x4xi32>
    %57 = stablehlo.iota dim = 1 : tensor<4x4xi32>
    %c = stablehlo.constant dense<0> : tensor<i32>
    %58 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32>
    %59 = stablehlo.add %56, %58 : tensor<4x4xi32>
    %60 = stablehlo.compare EQ, %59, %57, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1>
    %61 = stablehlo.convert %60 : (tensor<4x4xi1>) -> tensor<4x4xf32>
    %62 = stablehlo.transpose %55, dims = [0, 2, 1] : (tensor<5x3x3xf32>) -> tensor<5x3x3xf32>
    %c_12 = stablehlo.constant dense<0> : tensor<i32>
    %63 = stablehlo.broadcast_in_dim %c_12, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %c_13 = stablehlo.constant dense<0> : tensor<i32>
    %64 = stablehlo.broadcast_in_dim %c_13, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %65 = stablehlo.concatenate %63, %64, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %66 = stablehlo.broadcast_in_dim %61, dims = [1, 2] : (tensor<4x4xf32>) -> tensor<5x4x4xf32>
    %67 = "stablehlo.scatter"(%66, %65, %62) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 1, 2], scatter_dims_to_operand_dims = [1, 2]>, unique_indices = true}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      stablehlo.return %arg2 : tensor<f32>
    }) : (tensor<5x4x4xf32>, tensor<2xi32>, tensor<5x3x3xf32>) -> tensor<5x4x4xf32>
    %68 = stablehlo.broadcast_in_dim %cst, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %69 = stablehlo.concatenate %arg0, %68, dim = 1 : (tensor<5x3xf32>, tensor<5x1xf32>) -> tensor<5x4xf32>
    %70 = stablehlo.dot_general %67, %69, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<5x4x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
    return %70, %27 : tensor<5x4xf32>, tensor<5x3x3xf32>
  }
}
```
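For anyone hitting this in the meantime, here is a possible workaround sketch (hypothetical and unverified on the affected GPU runtime): build H by concatenation instead of `H.at[:3, :3].set(...)`, which avoids the `stablehlo.scatter` visible in the dump above. `fn_workaround` is a stand-in for the original `fn`:

```python
import jax
import jax.numpy as jnp

@jax.vmap
def fn_workaround(x):
  R1 = jnp.array([[x[0], 0, 0], [0, x[0], 0], [0, 0, x[0]]])
  R2 = jnp.array([[x[0], 0, 0], [0, x[1], 0], [0, 0, x[2]]])
  # Assemble the homogeneous transform row-wise instead of scattering R2.T
  # into jnp.eye(4); no scatter op should appear in the lowered HLO.
  top = jnp.concatenate([R2.T, jnp.zeros((3, 1))], axis=1)   # rows 0..2 of H
  bottom = jnp.array([[0.0, 0.0, 0.0, 1.0]])                 # row 3 of H
  H = jnp.concatenate([top, bottom], axis=0)
  pos = H @ jnp.concatenate([x, jnp.array([1.0])])
  return pos, R1

x_v = jnp.zeros((5, 3))
jax.jit(fn_workaround).lower(x_v).compile()
```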