I've found that the output of jt.triton_call differs depending on when and where certain metaparameters (I suspect the ones related to the grid) are defined.
Specifically, for the Triton repo's matmul kernel (source):
(1) jt.triton_call returns a matrix of NaNs from the second call onwards (the first call is correct) if the metaparameters BLOCK_M, BLOCK_N, BLOCK_K, and SPLIT_K are passed directly into the function call.
(2) jt.triton_call returns correct results when those metaparameters are selected via triton.autotune (rather than passed directly into jt.triton_call).
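For reference, Triton's matmul wrapper derives the launch grid from these same metaparameters, as (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N), SPLIT_K). The plain-Python sketch below uses hypothetical helper names (matmul_grid, and cdiv as a stand-in for triton.cdiv); it is not jax_triton API. It only shows that changing BLOCK_M or SPLIT_K changes the grid, so when the metaparameters are passed directly to jt.triton_call, the grid presumably has to be computed from exactly the same values.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division; a stand-in for triton.cdiv."""
    return (a + b - 1) // b


def matmul_grid(M: int, N: int, meta: dict) -> tuple:
    """Hypothetical helper: launch grid for Triton's matmul kernel.

    Axis 0 has one program per (BLOCK_M x BLOCK_N) output tile;
    axis 1 has one program per K-split (SPLIT_K of them).
    """
    return (cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"]), meta["SPLIT_K"])


meta = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 2}
print(matmul_grid(1000, 1000, meta))  # (256, 2)

# Same problem size, different BLOCK_M: the grid shrinks, so a grid computed
# from one set of values does not fit a kernel compiled with another.
print(matmul_grid(1000, 1000, {**meta, "BLOCK_M": 128}))  # (128, 2)
```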
Simply importing Triton's matmul_perf_model (source) also affects this: with the import, jt.triton_call fails (NaN outputs, as described in (1)) from the second call onwards; with the import commented out, it fails from the third call onwards.
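One detail that may be relevant, offered only as a hypothesis: when SPLIT_K > 1, Triton's matmul kernel writes its result with tl.atomic_add instead of tl.store, so each of the SPLIT_K programs adds a partial product into C, and C presumably has to be zero before every launch. The NumPy sketch below emulates that write-back with a hypothetical split_k_matmul helper (not Triton or jax_triton code; it uses contiguous K slices for simplicity, whereas the real kernel interleaves BLOCK_K chunks). A zeroed buffer gives A @ B; a buffer that still holds stale data, such as a reused allocation, gives a corrupted result.

```python
import numpy as np


def split_k_matmul(a, b, c, split_k):
    """Hypothetical emulation of the SPLIT_K > 1 write-back.

    Each of the `split_k` programs computes a partial product over its
    share of K and *adds* it into `c` (tl.atomic_add in the real kernel),
    so the result is c_initial + a @ b. `c` is modified in place.
    """
    k = a.shape[1]
    bounds = np.linspace(0, k, split_k + 1).astype(int)
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        c += a[:, lo:hi] @ b[lo:hi, :]  # accumulate, never overwrite
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8))
b = rng.standard_normal((8, 3))

ok = split_k_matmul(a, b, np.zeros((4, 3)), split_k=2)
print(np.allclose(ok, a @ b))  # True

stale = np.full((4, 3), np.nan)  # stand-in for a buffer holding leftover data
bad = split_k_matmul(a, b, stale, split_k=2)
print(np.isnan(bad).all())  # True
```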
I am attaching a script that reproduces this behavior.
I'm wondering whether this is expected behavior and, if so, which jax_triton conventions I should follow when passing metaparameters/tl.constexpr values. In general, the boundary between args and metaparams seems a bit vague: is a parameter a metaparameter if and only if it is a constexpr?
Thanks for the help!
matmul_repro.txt