
Split FIL infer_k into phases to speed up compilation (when a patch is applied) #4148

Merged: 8 commits merged into rapidsai:branch-21.10 on Aug 11, 2021

Conversation

@levsnv (Contributor) commented Aug 5, 2021

FIL takes several minutes to compile every time, even in release builds (up to 15 minutes in debug). Combined with an occasional larger recompile, this makes iterating on the code much slower. This change reduces the release compile time of `infer.cu` to 18 s, and probably yields a similar speedup in debug builds.

Some phases depend on fewer template parameters than the whole `infer_k`. If we merely avoid inlining those pieces, a lot of the code will no longer be duplicated 240 times (3 storage_type × 2 cols_in_shmem × 4 NITEMS × 5 leaf_algo × 2 branch_can_be_categorical). Since those functions are called once per `rows_per_block` (and once per whole forest within that), the runtime overhead should be low enough; an empirical test confirms this.

We are keeping the default compilation joint (fully inlined) due to the theoretical uncertainty around function-call overhead; the split build is enabled by applying a patch.
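To illustrate the approach, here is a minimal sketch (not the actual cuML source; `finalize_block_output` and all signatures are invented for this example). A phase that depends only on `NITEMS` is factored into a `__noinline__` device function, so it is compiled once per `NITEMS` value (4 times) instead of once per full parameter combination (240 times):

```cuda
// Minimal sketch of the split; names and signatures are hypothetical.
// Only the __noinline__ / template-parameter structure mirrors this PR.

// This phase depends only on NITEMS. __noinline__ keeps nvcc from
// inlining (and thus duplicating) its body into every instantiation of
// infer_k: it is instantiated once per NITEMS value, i.e. 4 times.
template <int NITEMS>
__noinline__ __device__ void finalize_block_output(float* out,
                                                   const float* acc,
                                                   int num_outputs)
{
  for (int i = threadIdx.x; i < num_outputs; i += blockDim.x) {
#pragma unroll
    for (int item = 0; item < NITEMS; ++item) {
      out[item * num_outputs + i] = acc[item * num_outputs + i];
    }
  }
}

// The kernel is still instantiated for every combination:
// 3 storage_type * 2 cols_in_shmem * 4 NITEMS * 5 leaf_algo *
// 2 branch_can_be_categorical = 240 variants. Only code that truly
// needs all the parameters is duplicated that many times.
template <typename storage_type, bool cols_in_shmem, int NITEMS,
          int leaf_algo, bool branch_can_be_categorical>
__global__ void infer_k(const float* acc, float* out, int num_outputs)
{
  // ... tree traversal that genuinely depends on all template parameters ...
  finalize_block_output<NITEMS>(out, acc, num_outputs);  // once per block
}
```

Because the out-of-line call happens once per block of rows rather than inside the per-node traversal loop, its overhead stays within measurement noise, as the measurements below show.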

@levsnv requested a review from canonizer August 5, 2021 23:04
@levsnv added labels: improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change) Aug 5, 2021
@levsnv (Contributor, Author) commented Aug 6, 2021

The performance impact is within measurement noise. Titan V measurements with `storage_type_t::DENSE` and `algo_t::BATCH_TREE_REORG`:

| dataset | n_rows | n_cols | n_classes | depth | n_trees | best time @ nitems=0, `__forceinline__` | best time @ nitems=0, `__noinline__` | slowdown % | real n_items for `__noinline__` |
|---|---|---|---|---|---|---|---|---|---|
| higgs | 1E6 | 28 | 2 | 12 | 700 | 48.49 ms | 48.13 ms | −1% | 3 |
| covtype (here: binary) | 1E6 | 54 | 2 | 9 | 700 | 37.93 ms | 37.51 ms | −1% | 4 |
| year (regression) | 1E6 | 90 | 1 | 10 | 700 | 50.91 ms | 52.08 ms | 2% | 4 |
| bosch numeric | 4E5 | 968 | 2 | 8 | 700 | 28.60 ms | 27.90 ms | −2% | 3 |
| epsilon | 2E5 | 2000 | 2 | 9 | 700 | 19.12 ms | 19.55 ms | 2% | 1 |
| tiny | 1E6 | 4 | 2 | 3 | 100 | 1.09 ms | 1.05 ms | −4% | 4 |

@levsnv marked this pull request as ready for review August 6, 2021 04:53
@levsnv requested a review from a team as a code owner August 6, 2021 04:53
@canonizer (Contributor) left a comment


Approved provided that the review comments are addressed.

Two review comments on `cpp/src/fil/infer.cu` (outdated, resolved)
@levsnv (Contributor, Author) commented Aug 6, 2021

rerun tests

22:08:16 Traceback (most recent call last):
22:08:16   File "/opt/conda/envs/rapids/lib/python3.7/site-packages/conda_package_handling/tarball.py", line 146, in extract
22:08:16     _tar_xf(fn, dest_dir)
22:08:16   File "/opt/conda/envs/rapids/lib/python3.7/site-packages/conda_package_handling/tarball.py", line 98, in _tar_xf
22:08:16     archive_utils.extract_file(tarball)
22:08:16   File "/opt/conda/envs/rapids/lib/python3.7/site-packages/conda_package_handling/archive_utils.py", line 15, in extract_file
22:08:16     raise InvalidArchiveError(tarball, error_str.decode('utf-8'))
22:08:16 conda_package_handling.exceptions.InvalidArchiveError: Error with archive /opt/conda/envs/rapids/pkgs/cudatoolkit-11.2.72-h2bc3f7f_0.tar.bz2.  You probably need to delete and re-download or re-create this file.  Message from libarchive was:

@levsnv changed the title from "Split FIL infer_k into phases to avoid compiling 240 instantiations of the whole kernel" to "Split FIL infer_k into phases to speed up compilation (when a patch is applied)" Aug 10, 2021
@codecov-commenter commented:
Codecov Report

❗ No coverage uploaded for pull request base (branch-21.10@e977f3e).
The diff coverage is n/a.


@@               Coverage Diff               @@
##             branch-21.10    #4148   +/-   ##
===============================================
  Coverage                ?   85.96%           
===============================================
  Files                   ?      232           
  Lines                   ?    18500           
  Branches                ?        0           
===============================================
  Hits                    ?    15904           
  Misses                  ?     2596           
  Partials                ?        0           
| Flag | Coverage Δ |
|---|---|
| dask | 47.76% <0.00%> (?) |
| non-dask | 78.56% <0.00%> (?) |

Flags with carried forward coverage won't be shown.


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update e977f3e...d869283.

@dantegd (Member) commented Aug 11, 2021

@gpucibot merge

@rapids-bot (bot) merged commit b59bcd5 into rapidsai:branch-21.10 Aug 11, 2021
@levsnv deleted the uninline-simple branch August 11, 2021 19:51
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request on Oct 9, 2023: Split FIL infer_k into phases to speed up compilation (when a patch is applied) (rapidsai#4148)


Authors:
  - https://github.com/levsnv

Approvers:
  - Andy Adinets (https://github.com/canonizer)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4148
Labels: CUDA/C++, improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change)