Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor custom FPx cast #363

Merged
merged 11 commits into from
Jun 17, 2024
Merged

Refactor custom FPx cast #363

merged 11 commits into from
Jun 17, 2024

Conversation

gau-nernst
Copy link
Collaborator

@gau-nernst gau-nernst commented Jun 14, 2024

Closes #354

TODO:

  • Check torch.compile
  • Benchmark before and after
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py

8841094 (main)

elem_dtype use_fp4_custom_triton_dequant_kernel q_time_us q_mem_bw_tb_s dq_time_us dq_mem_bw_tb_s
torch.float8_e4m3fn False 532.20 0.26 554.40 0.25
torch.float8_e5m2 False 532.83 0.26 551.01 0.25
fp6_e2m3 False 574.66 0.24 258.41 0.53
fp6_e3m2 False 577.37 0.24 258.86 0.53
fp4_e2m1 False 682.13 0.17 254.72 0.45
fp4_e2m1 True 12251.34 0.01 190.39 0.60

2690b92 (this PR)

elem_dtype use_fp4_custom_triton_dequant_kernel q_time_us q_mem_bw_tb_s dq_time_us dq_mem_bw_tb_s
torch.float8_e4m3fn False 531.62 0.26 552.91 0.25
torch.float8_e5m2 False 530.52 0.26 550.33 0.25
fp6_e2m3 False 572.89 0.24 551.21 0.25
fp6_e3m2 False 576.62 0.24 551.92 0.25
fp4_e2m1 False 680.27 0.17 255.10 0.45
fp4_e2m1 True 12248.68 0.01 191.09 0.60

Dequant is 2x slower because I replaced LUT-based denormal handling with a more generic logic. @vkuzo Should I add back the LUT-based logic (check specifically for E2M3 E3M2 E2M1)? If we are interested in performance then perhaps we can generate a LUT for all bit patterns and cache it.

UPDATE

95f4582 (this PR v2)

elem_dtype use_fp4_custom_triton_dequant_kernel q_time_us q_mem_bw_tb_s dq_time_us dq_mem_bw_tb_s
torch.float8_e4m3fn False 531.54 0.26 554.03 0.25
torch.float8_e5m2 False 532.41 0.26 551.63 0.25
fp6_e2m3 False 574.66 0.24 258.28 0.53
fp6_e3m2 False 576.53 0.24 258.76 0.53
fp4_e2m1 False 682.26 0.17 517.65 0.22
fp4_e2m1 True 12247.99 0.01 190.48 0.60

Now FP4_E2M1 is slower lol. Feel like this should be bandwidth-limited. It might be register-limited also? Will do some profiling + make sure torch.compile run optimally. Interesting that native PyTorch float8 dequant is slower.

UPDATE 2

dcd5a05 (this PR v3)

elem_dtype use_fp4_custom_triton_dequant_kernel q_time_us q_mem_bw_tb_s dq_time_us dq_mem_bw_tb_s
torch.float8_e4m3fn False 532.65 0.26 554.13 0.25
torch.float8_e5m2 False 532.07 0.26 551.30 0.25
fp6_e2m3 False 574.76 0.24 258.48 0.53
fp6_e3m2 False 576.85 0.24 258.36 0.53
fp4_e2m1 False 681.22 0.17 254.14 0.45
fp4_e2m1 True 12249.79 0.01 190.49 0.60

Speed recovered 😊

Copy link

pytorch-bot bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/363

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bd64efc with merge base 664f073 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 14, 2024
@gau-nernst gau-nernst requested a review from vkuzo June 15, 2024 00:58
@vkuzo
Copy link
Contributor

vkuzo commented Jun 15, 2024

Dequant is 2x slower because I replaced LUT-based denormal handling with a more generic logic.

2x is a sizeable regression, how about keeping the LUT for the formats we already have it for and having a generic fallback for the other formats? People can then optimize format by format individually if they want.

@gau-nernst
Copy link
Collaborator Author

gau-nernst commented Jun 15, 2024

@vkuzo I have updated the dequant denormal implementation. No speed regression anymore (I updated the results in the 1st post). Didn't need to use the hard-coded LUT from your implementation. If torch compiler does constant folding and loop unrolling properly, I think my implementation should match your previous implementation exactly.

If possible, you can benchmark on your GPUs to make sure 100% there is no regression.

@gau-nernst gau-nernst marked this pull request as ready for review June 15, 2024 07:35
@vkuzo
Copy link
Contributor

vkuzo commented Jun 17, 2024

Here are results on an H100: https://gist.github.com/vkuzo/324256b8defd0231852a23cbb34f49a6, I see no meaningful change in performance, awesome stuff


def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
"""
TODO(future): check if LUT for everything is faster than bit shifting,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this comment still relevant?

maybe add a docblock?

Copy link
Collaborator Author

@gau-nernst gau-nernst Jun 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using LUT for everything in dequant might be faster, like current NF4 implementation. I haven't benchmarked so I'm not sure.
I didn't add a docblock here since I think this is kinda an internal function. But a simple doc won't hurt. Will add some doc for this and quant function above. I already added a short description for these 2 functions at the top of the file.

F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have a docblock?

@msaroufim msaroufim merged commit eb1511e into pytorch:main Jun 17, 2024
13 checks passed
@gau-nernst gau-nernst deleted the custom_fpx branch June 17, 2024 15:21
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* refactor custom fp cast

* add dequant

* small formating

* compile with fullgraph=True

* add fullgraph=true

* undo

* add another version

* fast path for mbits=1

* add back docstring
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make custom FPx dtype conversion easier to use
4 participants