Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][CUDA] Improve the performance of scatter_nd #8479

Merged

Conversation

zhuwenxi
Copy link
Contributor

  1. Split into 2 kernels, one does the "Init" and another does the "Update".
    Thus they can have different Grid/Block configurations to better utilize
    SMs.
  2. Use atomic_add instead of direct assignment, which could avoid the race
    condtion when multiple indices point to the same location of the output
    tensor. With this moidification, it's safe now to use more CUDA threads
    to gain more parallelism.

Detail discussion: https://discuss.tvm.apache.org/t/topi-cuda-scatter-nd-has-a-very-poor-performance-on-cuda-backend-1000x-slower-than-hand-written-cuda-code/10426

@zhuwenxi
Copy link
Contributor Author

@tkonolige Could you help review this PR? Thank you.

1. Split into 2 kernels, one does the "Init" and another does the "Update".
   Thus they can have different Grid/Block configurations to better utilize
   SMs.
2. Use atomic_add instead of direct assignment, which could avoid the race
   condtion when multiple indices point to the same location of the output
   tensor. With this moidification, it's safe now to use more CUDA threads
   to gain more parallelism.
@zhuwenxi zhuwenxi force-pushed the feature/wenxizhu/improve-scatter-performance branch from 4107191 to 930043e Compare July 15, 2021 09:32
@zhuwenxi zhuwenxi changed the title [TOPI][CUDA] Improve the performance of scatter_nd by: [TOPI][CUDA] Improve the performance of scatter_nd Jul 15, 2021
Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks zhuwenxi! Do you have performance numbers for the PR? I'd be interested in seeing them.

blockDim = data_ptr.shape[-1]

ib.scope_attr(bidx, "thread_extent", gridDim)
ib.scope_attr(tidx, "thread_extent", blockDim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some cases this dimension will be very small. Can you instead split the full shape by max_num_threads?

with ib.new_scope():
bidx = te.thread_axis("blockIdx.x")
tidx = te.thread_axis("threadIdx.x")
gridDim = fused_indices_dimension # 32 * 600 = 19200
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
findex = j
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've set j = tidx and then only use it in one spot. Why not just use tidx everywhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines -790 to -795
# For now we avoid parallizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment about how we are doing parallelism (we are thread-parallel over all the update dimension and each block handles one set of indices?)

Copy link
Contributor

@CaptainDuke CaptainDuke Jul 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We follow the original parallelism scheme, but replace ib.for_range() with blockIdx.y.
Atomic_add guarantees correctness when mode=="add"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the comment in the code to reflect this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

- Split ScatterND kernel into 2 sub-kernels using ib.new_scope()

- Replace ib.for_range() with blockIdx.y

- Using atomic_add when mode == "add"

- Keep threadIdx.x less than max_threads of GPU
@zhuwenxi
Copy link
Contributor Author

zhuwenxi commented Jul 20, 2021

@tkonolige about the performance comparison, it's 23 ms vs. 4.9 ms on my NV T4 card, for the case I provided in https://discuss.tvm.apache.org/t/topi-cuda-scatter-nd-has-a-very-poor-performance-on-cuda-backend-1000x-slower-than-hand-written-cuda-code/10426.

@zhuwenxi
Copy link
Contributor Author

@tkonolige We just upstream a commit to fix a UT and the comment issue. The remaining fixes are on the way.

Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide timing information for a variety of shapes and ranks. I just want to make sure this is faster on all inputs.

offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
out[index] = updates[by * fused_updates_dimension + j]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move updates[by * fused_updates_dimension + j] outside of the if statements?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines -790 to -795
# For now we avoid parallizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the comment in the code to reflect this?

@CaptainDuke
Copy link
Contributor

CaptainDuke commented Jul 21, 2021

Could you provide timing information for a variety of shapes and ranks. I just want to make sure this is faster on all inputs.

ScatterND_performance

@tkonolige
We evalutate the performance with 3 types of ranks and shapes. Time (nanoseconds) is collected using Nsight System.

So long as the original with ib.for_range() as i is large enough, the separated two kernels would enlarge dimGrid and achieve better parallelism significantly.

Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance results look great! Could you also test 1. where indices is small (~10) and updates is large and 2. where indices is large and updates is size 1.

Comment on lines 794 to 796
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is no longer valid right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted


# For better performance, we introduce blockIdx.y to implement for-loops
# within one thread.
# Atomic_add guarantees correctness when mode=="add"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Atomic_add guarantees correctness when mode=="add"
# The code is parallel over the scattered indices, so we use atomic_add to guarantee correctness when mode=="add".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines -815 to -817
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you keep this comment. I believe it still holds

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

@CaptainDuke
Copy link
Contributor

Performance results look great! Could you also test 1. where indices is small (~10) and updates is large and 2. where indices is large and updates is size 1.

ScatterND_test2

@tkonolige
2 cases tested.

@CaptainDuke
Copy link
Contributor

CaptainDuke commented Jul 22, 2021

@tkonolige
We found that some test cases failed since automic_add from CUDA doesn't support int64 data type, so we add fallback implementation to pass these test cases.

Do you have any suggestions on this fallback?

- Atomic_add from CUDA doesn't support int64 data type
- Change "ind{i}" to "ind%d"%i, where names of relay.var could correctly display
@CaptainDuke CaptainDuke force-pushed the feature/wenxizhu/improve-scatter-performance branch from 83e8aa6 to 1faa97a Compare July 26, 2021 08:06
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
with ib.new_scope():
if updates.dtype == "int64" and mode == "add":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the correct way to check for atomic add support. @masahi What is the correct way?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately there is not a good way. I think we should encode atomic support information to a target description (similarity to @Lunderberg's vulkan work)

For now, atomic is not supported by vulkan and metal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd agree, the target would be the best location for checking atomic support. I have it on my to-do list to document/RFC which parameters should be standardized across target kinds, so that they'll be available for use in strategies/optimizations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CaptainDuke you need to check for Vulkan or metal here. Can you also add a comment as to why we have this if statement.

@CaptainDuke
Copy link
Contributor

CaptainDuke commented Jul 27, 2021

@tkonolige
I have committed several times but CI failed at different stages. Any suggestions in view of this situation?

More over, one test case error I can not reproduce, jenkins log: test_gru(), which seems non-related to this PR

Above results are computed on CPU, since target and device were hardcode #8565

@tkonolige
Copy link
Contributor

@CaptainDuke CI has been having issues. Just push a new commit and it will re-run.

Comment on lines 836 to 837
bdim_x = ceil_div(fused_updates_dimension, tdim)
bdim_y = fused_indices_dimension
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For large input sizes, this creates too many blocks. Try with input sizes data=21, 3, 2600, 212, indices=4, 1, 1, 2600, 212, updates=1, 1, 2600, 212.

Copy link
Contributor

@CaptainDuke CaptainDuke Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

840                bx = te.thread_axis("blockIdx.y")
841                by = te.thread_axis("blockIdx.x")

According to Maximum x-dimension of a grid of thread blocks = 2^31 -1, I exchange blockIdx.x and blockIdx.y to avoid out of bounds.

For the given input sizes, performance on mode="add" was

time: 784638840
grid=(1,1,1), block=(1,1,1)

v.s.

time: 105102068 + 2141897 = 107243965.
grid=(34725600,1,1), block=(1,1,1)
grid=(551200,1,1), block=(1,1,1)

7.3x faster

index += offset * indices[by + l * fused_indices_dimension]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[up_index]
Copy link
Member

@masahi masahi Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For update mode, does this give deterministic output? To me it seems it doesn't.

Copy link
Contributor

@CaptainDuke CaptainDuke Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output is calculated via the following equation:

  output = np.copy(data)
  update_indices = indices.shape[:-1]
  for idx in np.ndindex(update_indices):
      output[indices[idx]] = updates[idx]

The order of iteration in the above loop is not specified. In particular, indices should not have duplicate entries: that is, if idx1 != idx2, then indices[idx1] != indices[idx2]. This ensures that the output value does not depend on the iteration order.

@masahi
According to the defination of ScatterND in ONNX, output does not depend on the iteration order.

Based on the above assumption, we replace the original with ib.for_range(0, fused_indices_dimension) as i with blockIdx.y, where bim_y = fused_indices_dimension

Is this what you concerned ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't matter what ONNX says. If the previous implementation gives a deterministic output, performance improvement shouldn't break that. If you use atomic for add mode, then I assume that multiple threads compete for the same write index. This leads to non-determinstic output for update mode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masahi
I see. So, should I fallback to previous algorithm when mode="update"? Or any suggestions. Thanks.

Copy link
Member

@masahi masahi Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, for update the prev algo should be used.

If you do care about the performance improvement for update mode, we can add a new attribute allow_non_deterministic to scatter_nd op, which is False by default. And change ONNX frontend to emit scatter_op with allow_non_deterministic = True, which will allow the new code path for update mode as well. I think we can also choose this option if @tkonolige thinks this is reasonable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll fallback to previous algo for update.
For allow_non_deterministic feature, maybe we could fire a new PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that sounds good, we can discuss with more people then. This has been on my mind for a while, since both our scatter and scatter_nd op sacrifice performance for deterministic output, while all other frameworks make the opposite choice (they say output is undefined when indices are not unique).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Looking forward to further improvement

Comment on lines 791 to 792
# For now we avoid parallizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is not valid anymore right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
with ib.new_scope():
if updates.dtype == "int64" and mode == "add":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CaptainDuke you need to check for Vulkan or metal here. Can you also add a comment as to why we have this if statement.

Comment on lines 840 to 841
bx = te.thread_axis("blockIdx.y")
by = te.thread_axis("blockIdx.x")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the names so they match the dimensions? Alternatively rename them to reflect what they are indexing over.

Copy link
Contributor

@CaptainDuke CaptainDuke Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tkonolige @masahi Check for vulkan & metal added.
Comment added.
Names updated.

# For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64"
# So we fallback to normal algorithm, using "+=" rather than atomic_add

# TODO:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please put a username on the TODO (your username assuming you will do this).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

python/tvm/topi/cuda/scatter.py Outdated Show resolved Hide resolved
@CaptainDuke
Copy link
Contributor

@tkonolige @masahi
All checks have been passed. Ready to merge

mode == "update"
or cur_target_kind("vulkan")
or cur_target_kind("metal")
or (updates.dtype == "int64" and mode == "add")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think atomic is only supported for 32 bit. So float64 or int16 should also be catched here. Also since now you have mode == "update check above, there is no need to check mode == "add".

I suggest swapping then and else block and make the condition be:

if mode == "add" and target not in ["vulkan", metal"] and updates.dtype not in ["int32", "float32"]:
   use atomic code path
else
   ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
Tiny fix: updates.dtype in ["int32", "float32]: rather than not in

# Since multiple threads compete for the same write index, which leads to
# non-determinstic output for update mode. We could add a new attribute
# "allow_non_deterministic" to scatter_nd op, which is False by default.
# And change ONNX frontend to emit scatter_op with allow_non_deterministic = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the reference to "ONNX".

We could add a new attribute,  "allow_non_deterministic", which can be conditionally set to True by each frontend when non-determinsm is allowed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -764,6 +764,9 @@ def scatter_nd(data, indices, updates, mode):
"""
_verify_scatter_nd_inputs(data, indices, updates)

def cur_target_kind(kind="cuda"):
return tvm.target.Target.current(allow_none=False).kind == tvm.target.Target(kind).kind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tvm.target.Target.current(allow_none=False).kind == kind

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

I use str(tvm.target.Target.current(allow_none=False).kind)
convert <class 'tvm.target.target.TargetKind'> to string

@masahi masahi merged commit 887324f into apache:main Aug 1, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
* [TOPI][CUDA] Improve the performance of scatter_nd by:

1. Split into 2 kernels, one does the "Init" and another does the "Update".
   Thus they can have different Grid/Block configurations to better utilize
   SMs.
2. Use atomic_add instead of direct assignment, which could avoid the race
   condtion when multiple indices point to the same location of the output
   tensor. With this moidification, it's safe now to use more CUDA threads
   to gain more parallelism.

* Fix python code format.

* FIX: [TOPI][CUDA] Improve the performance of scatter_nd apache#8479

- Split ScatterND kernel into 2 sub-kernels using ib.new_scope()

- Replace ib.for_range() with blockIdx.y

- Using atomic_add when mode == "add"

- Keep threadIdx.x less than max_threads of GPU

* Comment added

* Add fallback implementation when "mode=add" meets int64

- Atomic_add from CUDA doesn't support int64 data type
- Change "ind{i}" to "ind%d"%i, where names of relay.var could correctly display

* Python format

* Fix line too long

* CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Exchange blockIdx.x and blockIdx.y

* check for Vulkan or metal

* Fallback to previous algorithm when mode==update

* Update python/tvm/topi/cuda/scatter.py

Co-authored-by: Tristan Konolige <[email protected]>

* Assign TODO

* Swapping then and else block

Co-authored-by: wenxizhu <[email protected]>
Co-authored-by: CaptainDuke <[email protected]>
Co-authored-by: Tristan Konolige <[email protected]>
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* [TOPI][CUDA] Improve the performance of scatter_nd by:

1. Split into 2 kernels, one does the "Init" and another does the "Update".
   Thus they can have different Grid/Block configurations to better utilize
   SMs.
2. Use atomic_add instead of direct assignment, which could avoid the race
   condtion when multiple indices point to the same location of the output
   tensor. With this moidification, it's safe now to use more CUDA threads
   to gain more parallelism.

* Fix python code format.

* FIX: [TOPI][CUDA] Improve the performance of scatter_nd apache#8479

- Split ScatterND kernel into 2 sub-kernels using ib.new_scope()

- Replace ib.for_range() with blockIdx.y

- Using atomic_add when mode == "add"

- Keep threadIdx.x less than max_threads of GPU

* Comment added

* Add fallback implementation when "mode=add" meets int64

- Atomic_add from CUDA doesn't support int64 data type
- Change "ind{i}" to "ind%d"%i, where names of relay.var could correctly display

* Python format

* Fix line too long

* CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Empty, for CI pass

* Exchange blockIdx.x and blockIdx.y

* check for Vulkan or metal

* Fallback to previous algorithm when mode==update

* Update python/tvm/topi/cuda/scatter.py

Co-authored-by: Tristan Konolige <[email protected]>

* Assign TODO

* Swapping then and else block

Co-authored-by: wenxizhu <[email protected]>
Co-authored-by: CaptainDuke <[email protected]>
Co-authored-by: Tristan Konolige <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants