CUDA device API & VerifyGPUCode pass update #5898
It seems the extra simplify rule caused a unit test error; I'll try to fix that.
Please split the rewrite simplify changes into their own PR and add unit test cases for the newly added rules. It would be great if we could write down a few lines of proof sketches to make sure that these rules are correct.
Thanks, I'll reconsider the simplify rules to make sure they don't introduce extra bugs, and I'll split them into another PR.
6d3b7b4 to 736441a
@merrymercy Cast fixed and CI cleaned.
- For LLVM 10+ we need to avoid calling Align with 0, or else we get a crash.
- For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+), but for ROCm < 3.5 we want code object 2.
- As we want to separate codegen from the API, we need to add a device API query for the version. But everyone else now wants one, too. (I only filled it in for CUDA for now.)
- I'm throwing in an addition of kMaxRegistersPerBlock for ROCm. This was introduced for CUDA in apache#5898.
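The ROCm rule in the commit message above can be sketched as a small version check. This is only an illustration in Python, not the code from the commit: the function name `select_code_object_version` and the `(major, minor)` tuple representation of the ROCm version are assumptions made for the example.

```python
from typing import Tuple


def select_code_object_version(rocm_version: Tuple[int, int]) -> int:
    """Pick a code object version from a (major, minor) ROCm version.

    Hypothetical helper: mirrors the rule stated in the commit message,
    code object 3 for ROCm 3.5+ and code object 2 for anything older.
    """
    # Tuples compare element by element, so (3, 10) >= (3, 5) and (4, 0) >= (3, 5).
    return 3 if rocm_version >= (3, 5) else 2


if __name__ == "__main__":
    for version in [(3, 3), (3, 5), (4, 0)]:
        print(version, "-> code object", select_code_object_version(version))
```

Comparing tuples rather than a float such as `3.5` avoids misordering versions like 3.10, which is one reason a version query on the device API is more convenient than parsing a version string in codegen.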
* Add kMaxRegistersPerBlock device api for cuda
* Add vectorize check to verify_gpu_code
* Lint fix
* Cast fix
This PR is part of #5883 and consists of 2 small improvements:
- Add kMaxRegistersPerBlock to the device API, mainly used by the CUDA API to get the local memory limitation.
- Add a vectorize check to the VerifyGPUCode pass that depends on the vectorized block data size, since nvcc has stricter limitations than LLVM.
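The idea behind the vectorize check can be sketched in a few lines of Python. This is a simplified, self-contained illustration and not the pass itself: `VectorType`, `check_vector_size`, and the 16-byte limit in the usage example are hypothetical names and values chosen for the example, assuming the check compares the total byte size of a vectorized access against a configurable maximum.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class VectorType:
    """Hypothetical stand-in for a vectorized data type."""
    bits: int   # width of one element in bits, e.g. 32 for float32
    lanes: int  # number of elements in the vector

    @property
    def total_bytes(self) -> int:
        # Round each element up to whole bytes, then multiply by the lane count.
        return ((self.bits + 7) // 8) * self.lanes


def check_vector_size(dtype: VectorType, max_vector_bytes: int) -> List[str]:
    """Return a list of error messages; an empty list means the check passed."""
    errors = []
    # Scalars (lanes == 1) are never rejected by this check.
    if dtype.lanes > 1 and dtype.total_bytes > max_vector_bytes:
        errors.append(
            f"vector of {dtype.lanes} x {dtype.bits}-bit elements uses "
            f"{dtype.total_bytes} bytes, above the limit of {max_vector_bytes}"
        )
    return errors


if __name__ == "__main__":
    limit = 16  # assumed limit, for illustration only
    print(check_vector_size(VectorType(bits=32, lanes=4), limit))  # fits in 16 bytes
    print(check_vector_size(VectorType(bits=32, lanes=8), limit))  # 32 bytes, rejected
```

Returning a list of messages instead of raising lets a verifier collect every violation in one traversal, which is useful when the result is only used to discard invalid schedules.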
The remaining 2 bug fixes will be in separate PRs.