Description
Hi, while applying FusedAttention via jax-triton, we hit the following XLA error on an NVIDIA A100:
2023-08-28 03:06:51.319566: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.319790: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.319901: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320187: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320240: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320386: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320465: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320846: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/paxml/paxml/main.py", line 510, in
app.run(main, flags_parser=absl_flags.flags_parser)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/workspace/paxml/paxml/main.py", line 445, in main
_main(argv)
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 487, in _main
run(experiment_config=experiment_config,
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 420, in run
run_experiment(
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 285, in run_experiment
train.train_and_evaluate(
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/train.py", line 274, in train_and_evaluate
executor.start()
File "/workspace/paxml/paxml/executors.py", line 269, in start
_train_and_evaluate_common(
File "/workspace/paxml/paxml/executors.py", line 406, in _train_and_evaluate_common
program_output = train_program.run(partitioned_train_state, step_i)
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/programs.py", line 332, in run
new_step, new_state, train_outputs = self.train_step(
File "/workspace/paxml/paxml/programs.py", line 620, in train_step
return step + 1, *train_step(state, prng_key, inputs, static_args)
File "/workspace/paxml/paxml/trainer_lib.py", line 1634, in call
return pjitted_fn(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
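For context, a back-of-the-envelope sketch (not taken from the failing kernel): "Shared memory requested exceeds device resources" means the Triton-generated kernel asked for more shared memory per block than the device provides (on A100, up to roughly 164 KiB of dynamic shared memory per block, and only with an explicit opt-in; the default limit is 48 KiB). A flash-attention-style tiling keeps Q/K/V tiles in shared memory, and software pipelining (num_stages) multi-buffers the K/V tiles, so large block sizes overflow quickly. The block sizes, head_dim, and staging model below are illustrative assumptions, not the actual kernel parameters:

```python
# Rough shared-memory estimate for a flash-attention-style Triton tile.
# Block sizes, head_dim, and the staging model are illustrative assumptions.

def smem_bytes(block_m, block_n, head_dim, num_stages=2, dtype_bytes=2):
    q = block_m * head_dim * dtype_bytes       # Q tile (single-buffered)
    kv = 2 * block_n * head_dim * dtype_bytes  # K and V tiles
    return q + kv * num_stages                 # pipelining multi-buffers K/V

A100_SMEM_MAX = 164 * 1024  # approx. max dynamic shared memory per block (opt-in)
DEFAULT_SMEM = 48 * 1024    # default per-block limit without the opt-in

for bm, bn, stages in [(128, 128, 4), (128, 128, 2), (64, 64, 2)]:
    need = smem_bytes(bm, bn, head_dim=128, num_stages=stages)
    print(f"BLOCK_M={bm} BLOCK_N={bn} stages={stages}: "
          f"{need // 1024} KiB (fits A100 max: {need <= A100_SMEM_MAX})")
```

If this arithmetic matches the failing kernel, reducing BLOCK_M/BLOCK_N or num_stages in the fused-attention configuration is the usual way to bring the request back under the device limit.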
Steps to reproduce:
Add model variants to /root/.local/lib/python3.10/site-packages/paxml/tasks/lm/params/nvidia.py
Run without FusedAttention (passing case):
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7B
Run with FusedAttention (failing case):
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7BFA
Versions:
NVIDIA GPU info: 4x A100-SXM-80GB GPUs