CUDA device API & VerifyGPUCode pass update #5898

Merged: 4 commits into apache:master, Jun 25, 2020

Conversation

jcf94 (Contributor) commented Jun 23, 2020

This PR is part of #5883 and consists of two small improvements:

  1. Add kMaxRegistersPerBlock to the device API, mainly used by the CUDA API to get the local memory limitation.

  2. Add a vectorize check to the VerifyGPUCode pass based on the vectorized data size; NVCC has stricter limitations than LLVM. (A usage sketch follows this list.)
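
The following is a minimal sketch, not code from this PR's diff, of how the updated pass can be driven from Python. The constraint keys mirror VerifyGPUCode's constraint dict, and the concrete limits are placeholder values that would normally come from the device API attributes (on CUDA, kMaxRegistersPerBlock backs the local-memory bound):

```python
import tvm
from tvm import te

# Build a trivially GPU-bound schedule so there is something to verify.
n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))

mod = tvm.lower(s, [A, B])

# Placeholder limits; real values come from the device API attributes.
valid = tvm.tir.analysis.verify_gpu_code(
    mod["main"],
    {
        "max_local_memory_per_block": 64 * 1024,
        "max_threads_per_block": 1024,
        "max_vector_bytes": 16,  # the new vectorized-data-size bound
    },
)
print(valid)
```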

The remaining two bug fixes will be in separate PRs:

3. Update RewriteSimplify to fix a bug that occurred in vectorized GPU shared-memory cooperative fetching.
Test cases are in `tests/python/unittest/test_te_schedule_gpu_advanced.py`; see `test_vectorized_cooperative_fetching_x()` and `test_vectorized_cooperative_fetching_xy()` for more information. A sketch of the kind of schedule involved follows the two snippets below.

The code generated with the bug looks like this:
```
A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
(float32x4*)A_2[(((broadcast(((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)), 4) + (floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))*broadcast(512, 4))) + broadcast((k.outer.outer*64), 4)) + floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)))])
```
This eventually lowers to incorrect CUDA C instructions. It should be simplified to generate the correct RampNode instead:
```
A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
(float32x4*)A_2[ramp((((((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)) + (floordiv(threadIdx.x_1, 16)*512)) + (k.outer.outer*64)) + (floormod(threadIdx.x_1, 16)*4)), 1, 4)])
```
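
For context, here is a minimal sketch, assumed rather than taken from the PR's tests, of the kind of schedule that produces such vectorized cooperative fetching: the 32 threads of a block jointly load a tile of A into shared memory through 4-wide vector lanes.

```python
import tvm
from tvm import te

M, K = 64, 64
A = te.placeholder((M, K), name="A")
B = te.compute((M, K), lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
AA = s.cache_read(A, "shared", [B])

bx, tx = s[B].split(B.op.axis[0], factor=32)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

# Cooperative fetch: fuse A's copy loops, peel off a 4-wide vector lane,
# then spread the remaining iterations over the block's 32 threads.
s[AA].compute_at(s[B], bx)
fused = s[AA].fuse(*s[AA].op.axis)
fo, fi = s[AA].split(fused, factor=4)
ftx, _ = s[AA].split(fo, nparts=32)
s[AA].bind(ftx, te.thread_axis("threadIdx.x"))
s[AA].vectorize(fi)

print(tvm.lower(s, [A, B], simple_mode=True))
```
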
4. Add LegalizeInvalidAttach in `schedule_dataflow_rewrite.cc`.
This function legalizes the compute_at location when the target iterator of a compute_at has been split or fused. Without it, the following two cases crash:

```python
import tvm
from tvm import te

A = te.compute((10, 10), lambda i, j: 1.0, name='A')
B = te.compute((10, 10), lambda i, j: A[i][j], name='B')

# Case 1: Split an axis which is the target of a compute_at
s = te.create_schedule([B.op])
s[A].compute_at(s[B], B.op.axis[1])
s[B].split(B.op.axis[1], 2)

print(tvm.lower(s, [A, B], simple_mode=True))

# Case 2: Fuse an axis which is the target of a compute_at
s = te.create_schedule([B.op])
s[A].compute_at(s[B], B.op.axis[1])
s[B].fuse(B.op.axis[0], B.op.axis[1])

print(tvm.lower(s, [A, B], simple_mode=True))
```

We rebuild the attach relation in `Schedule::normalize()` to fix this.
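
For contrast, a sketch of the ordering that already worked before this fix: perform the transformation first, then attach to a leaf iterator that still exists (reusing A and B from the snippet above).

```python
# Split first, then compute_at on the surviving inner leaf; attaching
# after the transformation avoids the invalid-attach state that
# LegalizeInvalidAttach now repairs automatically.
s = te.create_schedule([B.op])
xo, xi = s[B].split(B.op.axis[1], 2)
s[A].compute_at(s[B], xi)

print(tvm.lower(s, [A, B], simple_mode=True))
```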

jcf94 (Contributor, Author) commented Jun 23, 2020

It seems the extra simplify rule caused a unit test error; I'll try to fix that.

Resolved review threads: src/arith/rewrite_simplify.cc (2), src/te/schedule/schedule_dataflow_rewrite.cc (1)
tqchen (Member) commented Jun 23, 2020

Please split the rewrite simplify change into its own PR and add unit test cases for the newly added rules. It would be great if we could write down a few lines of proof sketches to make sure these rules are correct.

jcf94 (Contributor, Author) commented Jun 24, 2020

> Please split the rewrite simplify change into its own PR and add unit test cases for the newly added rules. It would be great if we could write down a few lines of proof sketches to make sure these rules are correct.

Thanks, I'll revisit the simplify rules to make sure they don't introduce extra bugs, and split them into another PR.

@jcf94 force-pushed the jcf_github_tvm/gpu_related_bug_fix branch from 6d3b7b4 to 736441a on June 24, 2020 05:36
@jcf94 changed the title from "GPU related bug fix & Improve" to "CUDA device API & VerifyGPUCode pass update" on Jun 24, 2020
Resolved review threads: src/tir/analysis/verify_gpu_code.cc (2)
@merrymercy self-assigned this on Jun 24, 2020
@jcf94 requested a review from merrymercy on June 25, 2020 03:44
jcf94 (Contributor, Author) commented Jun 25, 2020

@merrymercy The cast is fixed and CI is clean.

@merrymercy merged commit 074a07e into apache:master on Jun 25, 2020
@jcf94 deleted the jcf_github_tvm/gpu_related_bug_fix branch on June 25, 2020 05:56
t-vi added a commit to t-vi/tvm that referenced this pull request on Jun 25, 2020
- For LLVM 10+ we need to avoid calling Align with 0, or else we get a crash.
- For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+), but for ROCm < 3.5 we want code object 2.
- As we want to separate codegen from the API, we need to add a device API query for the version. But everyone else now wants one, too. (I only filled it in for CUDA for now.)
- I'm throwing in an addition of kMaxRegistersPerBlock for ROCm. This was introduced for CUDA in apache#5898.
tqchen pushed a commit that referenced this pull request on Jun 25, 2020 (same commit message as above)
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 30, 2020
* Add kMaxRegistersPerBlock device api for cuda

* Add vectorize check to verify_gpu_code

* Lint fix

* Cast fix
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 30, 2020 (…5920; same commit message as the t-vi commit above)
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Jul 2, 2020 (same commit message as the trevor-m commit above)
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Jul 2, 2020 (…5920; same commit message as the t-vi commit above)