Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Add __launch_bounds__ directive as part of the CUDA code generation #8678

Merged
merged 1 commit into from
Aug 7, 2021

Conversation

ArmageddonKnight
Copy link
Contributor

Short Summary

Sometimes, when executing CUDA kernels, we might encounter the error CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES (e.g., here). This happens because the nvcc compiler allocates too many registers per thread. In the case when we launch the CUDA kernel using too many threads, the GPU will notice that the CUDA kernel requests more registers than what are available on the chip and therefore refuse to launch the kernel.

This hence implies that we need a way of telling nvcc what to expect in terms of the number of threads per block. Luckily, the __launch_bounds__ directive can help us achieve what we want. In this patch, we add __launch_bounds__ as part of the CUDA code generation procedure. __launch_bounds__ will be automatically printed if it is detected that the number of threads per block is a constant integer value. Passing this information to nvcc allows it to spill registers if needed, which might hurt performance, but is still better than having a CUDA kernel that is not functional.

Q & A

Q: Would this affect the AutoTVM and the auto-scheduler submodule?

A: No. Although in those cases the number of threads keeps changing at each trial, the number will be set to a constant when it comes to the code generation phase. Furthermore, in the case when the number of threads per block is not a constant, __launch_bounds__ will simply not be printed.

Any feedback on this patch is appreciated. @comaniac @icemelon @yzhliu @yidawang

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @Hzfengsy @Laurawly @masahi @tqchen it would be better if you folks could also take a look, as this will change the generated CUDA codes for all existing workloads.

@junrushao
Copy link
Member

I am very much in favor of this approach. Thanks Bojian!

src/target/source/codegen_cuda.cc Outdated Show resolved Hide resolved
src/target/source/codegen_cuda.cc Show resolved Hide resolved
@junrushao
Copy link
Member

junrushao commented Aug 7, 2021

The syntax should be __launch_bounds__(maxThreadsPerBlock) or __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor).

However, in the failing test, I am seeing:

grid=(16,8,1),  block=(32,2,2)
...
__launch_bounds__(1)

Would you like to double check? @ArmageddonKnight

@ArmageddonKnight
Copy link
Contributor Author

ArmageddonKnight commented Aug 7, 2021

@junrushao1994 Thanks for letting me know. I had a look into this issue. The problem is caused by assigning threadIdx.x to iv->thread_tag rather than iv->var->name_hint (and therefore, the extractor is unable to correctly extract the number of threads per block). To address this issue, I extend the extractor to cover thread_tag's as well. At the same time, if the number of threads per block is extracted as 1, then __launch_bounds__ will NOT be printed.

I ran the test case again and it works locally.

image

@junrushao
Copy link
Member

Thanks Bojian! It makes perfect sense to me, which reminds me of a similar bug we encountered in TensorIR lowering :-)

@ArmageddonKnight
Copy link
Contributor Author

@vinx13 @comaniac @junrushao1994 FYI, the patch has passed all the checks. Could you please merge it if possible?

@junrushao junrushao merged commit bca155f into apache:main Aug 7, 2021
@junrushao
Copy link
Member

Thanks Bojian! It’s really super useful improvement in cuda codegen

@ArmageddonKnight
Copy link
Contributor Author

Thanks @junrushao1994

@ArmageddonKnight ArmageddonKnight deleted the bojian/CUDALaunchBounds branch August 7, 2021 14:52
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Aug 11, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants