Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Minor perf improvement for GPU scatter #7233

Merged
merged 16 commits into from
Jan 19, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Jan 8, 2021

This updates GPU scatter in two ways, to improve performance on GPU MaskRCNN (should be better for other workloads).

  • PyTorch frontend uses 1D scatter of one element to emulate inplace assignment arr[i] = v,

    end = _op.scatter(
    end,
    _op.expand_dims(_expr.const(dim), axis=0),
    _op.expand_dims(target_end, axis=0),
    axis=0,
    )
    Using sorting based scatter involves too much overhead for such small inputs. For small inputs, sequential scatter is better. The size threshold was chosen arbitrary (50).

  • The first kernel (initialization) of 4D scatter turns out very slow. It is actually much slower than the second sequential kernel, taking more than 10 milli sec of MaskRCNN runs as shown in the profile and trace below. It's likely the performance depends on input shape, but I found the way threading is done a bit strange. This PR changes the threading of the first kernel to be the same as other injective ops, to scale better regardless of input shapes.

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   14.56%  31.765ms        58  547.67us  28.160us  1.0697ms  fused_expand_dims_concatenate_1_kernel0
                    8.46%  18.457ms       483  38.213us     512ns  5.2853ms  [CUDA memcpy HtoD]
                    7.38%  16.101ms         4  4.0253ms  4.0085ms  4.0364ms  fused_nn_conv2d_add_nn_relu_20_kernel0
                    7.24%  15.799ms         1  15.799ms  15.799ms  15.799ms  fused_nn_conv2d_transpose_add_nn_relu_kernel0
                    6.85%  14.952ms      2287  6.5370us     640ns  7.9143ms  [CUDA memcpy DtoH]
                    5.34%  11.648ms         4  2.9121ms  2.9034ms  2.9217ms  fused_scatter_1_kernel0
                    3.84%  8.3777ms         1  8.3777ms  8.3777ms  8.3777ms  fused_nn_conv2d_add_nn_relu_16_kernel0
                    3.71%  8.0883ms         1  8.0883ms  8.0883ms  8.0883ms  fused_nn_conv2d_add_5_kernel0
                    3.25%  7.1009ms         2  3.5504ms  1.2477ms  5.8531ms  fused_dyn_full_kernel0
                    2.23%  4.8683ms         2  2.4341ms  365.28us  4.5030ms  sgemm_128x128x8_NT_vec
...
30.5532s  2.9201ms         (1000 256 7)        (32 1 1)         8        0B        0B         -           -           -           -  GeForce GTX 107         1         7  fused_scatter_1_kernel0 [3059]
30.5561s  519.65us            (1 256 7)        (32 1 1)        16        0B        0B         -           -           -           -  GeForce GTX 107         1         7  fused_scatter_1_kernel1 [3061]
...

These changes are not big deal, but it brings a good speed up on MaskRCNN: it cuts MaskRCNN runtime by 20 milli sec.

please review @mbrookhart @tkonolige

@tkonolige
Copy link
Contributor

tkonolige commented Jan 8, 2021

Would it be a better idea to have to separate scatter implementations (the parallel one and the sequential one) and let autotvm figure out which is better? Then we don't have to have all this special casing and magic input sizes.

Do you also have some benchmarks you could show for these changes? (I'm not clear what the second text block is showing)

@masahi
Copy link
Member Author

masahi commented Jan 8, 2021

The second text block is an excerpt from the output of nvprof --print-gpu-trace, showing elapsed time, launch config etc of each kernel executed, in order. The first line is for the initialization kernel, the second one the actual scatter kernel.

I don't have benchmark other than the data from MaskRCNN. For the first kernel of 4D scatter, since it is just a memcpy, I don't see why we should do threading differently than other injective ops. I hope we don't need thorough benchmarking to justify this change. After this change, the trace becomes (only the first line changes, note the elapsed time and thread launch config).

31.2518s  495.68us          (12250 1 1)      (1024 1 1)         8        0B        0B         -           -           -           -  GeForce GTX 107         1         7  fused_scatter_1_kernel0 [2980]
31.2523s  522.78us            (1 256 7)        (32 1 1)        16        0B        0B         -           -           -           -  GeForce GTX 107         1         7  fused_scatter_1_kernel1 [2982]

Would it be a better idea to have to separate scatter implementations (the parallel one and the sequential one) and let autotvm figure out which is better? Then we don't have to have all this special casing and magic input sizes.

hmm, this sounds better than picking a random threshold, but do we have existing uses of autotvm to make such decision? Given that scatter kernels are extern, I'm not sure if autotvm can work with them.

@tkonolige
Copy link
Contributor

hmm, this sounds better than picking a random threshold, but do we have existing uses of autotvm to make such decision? Given that scatter kernels are extern, I'm not sure if autotvm can work with them.

Autotvm does this for external libraries which are all extern, so it will work here.

I trust you when you say these are faster, I just wondered if you had done any benchmarking. Looking at the code, it seems it should be equally as fast, but sometimes it surprises you. That is when benchmarks are useful.

@masahi
Copy link
Member Author

masahi commented Jan 8, 2021

Yes, there are 4 calls to 4D scatter in MaskRCNN, the old kernel was taking 11.6 milli seconds on them in total, making it one of the bottlenecks as shown in the profile above. This change brings it down to 1.9873 milli seconds total and it is no longer a bottleneck. So this is a solid improvement.

I think the reason the old kernel was slow for this input (1000, 256, 7, 7) is because thread block is too small (32, 1, 1) and we are launching too many of them (1000 * 256 * 7 blocks).

@masahi
Copy link
Member Author

masahi commented Jan 8, 2021

Autotvm does this for external libraries which are all extern, so it will work here

@tkonolige I like the idea of separating sorting based implementation of scatter, so I want to try this. Can you point me where in the codebase autotvm deals with external libs? I found something like @autotvm.register_topi_compute("conv2d_cudnn.cuda") but I'm not sure how it interacts with normal conv2d ops.

One issue is that currently sorting based approach is only implemented for 1D scatter. For higher dimensions, I think sorting based approach is a bad idea. So dispatching decision needs to take input dimension into account (not sure if this could be a problem for autotvm or relay strategy).

@tkonolige
Copy link
Contributor

tkonolige commented Jan 8, 2021

@masahi Here is an example of having multiple implementations for the same op, with some of them being external. https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/x86.py#L371-L393 In this example, schedule_dense_cblas is just schedule_extern (

def schedule_dense_cblas(_, outs):
).

You can conditionally call strategy.add_implementation based on the input sizes (also in this example).

@masahi masahi marked this pull request as draft January 9, 2021 00:02
@masahi masahi force-pushed the gpu-scatter-improvement branch 2 times, most recently from ad00c94 to 57fc2d8 Compare January 18, 2021 13:30
@masahi masahi marked this pull request as ready for review January 18, 2021 13:34
@masahi
Copy link
Member Author

masahi commented Jan 19, 2021

@tkonolige @mbrookhart I separated the two scatter implementations, things should look clean now. The sequential one is chosen by default, and I confirmed that by tuning the scatter op the parallel one can be chosen.

Tuning the scatter op revealed an interesting issue in AutoTVM, discussed in https://discuss.tvm.apache.org/t/autotvm-cuda-runtime-error-when-tuning-extern-ops/8832/7. Thanks @FrozenGene for help.

Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Seems like splitting it into two implementations made things cleaner.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mbrookhart mbrookhart merged commit 2290cc0 into apache:main Jan 19, 2021
@mbrookhart
Copy link
Contributor

Thanks @masahi @tkonolige @FrozenGene

TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
* improve scatter 4d init

* do not launch sorting based scatter for small input

* do not use hard coded num threads

* separate sort based implementation

* register scatter as autotvm task

* add missing import

* fix strategy

* add dedicated schedule and dummy flop

* add test tuning script

* try adding dummy knob

* skip random_fill when a tuning workload is from scatter

This reverts commit 1fed883.

* cleanup memcpy ir

* remove scatter tuning script

* make sure zero init arguments

* add comment on why skip random init for scatter

* restore ctx sync

Co-authored-by: masa <[email protected]>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
* improve scatter 4d init

* do not launch sorting based scatter for small input

* do not use hard coded num threads

* separate sort based implementation

* register scatter as autotvm task

* add missing import

* fix strategy

* add dedicated schedule and dummy flop

* add test tuning script

* try adding dummy knob

* skip random_fill when a tuning workload is from scatter

This reverts commit 1fed883.

* cleanup memcpy ir

* remove scatter tuning script

* make sure zero init arguments

* add comment on why skip random init for scatter

* restore ctx sync

Co-authored-by: masa <[email protected]>
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
* improve scatter 4d init

* do not launch sorting based scatter for small input

* do not use hard coded num threads

* separate sort based implementation

* register scatter as autotvm task

* add missing import

* fix strategy

* add dedicated schedule and dummy flop

* add test tuning script

* try adding dummy knob

* skip random_fill when a tuning workload is from scatter

This reverts commit 1fed883.

* cleanup memcpy ir

* remove scatter tuning script

* make sure zero init arguments

* add comment on why skip random init for scatter

* restore ctx sync

Co-authored-by: masa <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants