Improve primitives for FP6 quant #248

Merged · 90 commits · May 25, 2024

Conversation

@gau-nernst (Collaborator) commented May 16, 2024

Address #208

TODO:

  • FP32/FP16/BF16 -> FP6
  • FP6 -> FP32/FP16/BF16
  • Add tests

On a (8192, 8192) tensor. Ryzen 5600 and 4070 Ti SUPER.

| device | dtype | op | time (ms) |
|--------|-----------|---------------------------|-----------|
| CPU | FP16->FP6 | original | 1140.27 |
| CPU | FP16->FP6 | ours | 384.479 |
| CPU | FP16->FP6 | original (num_threads=4) | 977.523 |
| CPU | FP16->FP6 | ours (num_threads=4) | 98.3557 |
| CPU | FP32->FP6 | original | 1033.14 |
| CPU | FP32->FP6 | ours | 374.142 |
| CPU | FP32->FP6 | original (num_threads=4) | 934.211 |
| CPU | FP32->FP6 | ours (num_threads=4) | 95.7996 |
| CUDA | FP16->FP6 | ours | 0.325222 |
| CUDA | FP32->FP6 | ours | 0.639134 |

NOTE:

  • original is torchao.ops.fp16_to_fp6_original() (from the original FP6-LLM repo plus qtorch quantization logic). It does not support CUDA.
  • On CPU there is a faster algorithm using only bit shifts, but it cannot be implemented efficiently with PyTorch + torch.compile (a rough sketch of the general conversion logic is shown below).
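
For orientation, here is a minimal, unoptimized pure-PyTorch sketch of what an FP32 -> FP6 E3M2 conversion involves (1 sign, 3 exponent, 2 mantissa bits, exponent bias 3, no inf/NaN). This is not the kernel added in this PR; the function name, the clamp-to-28.0 saturation behaviour, and the log2-based rounding are illustrative assumptions.

```python
import torch
from torch import Tensor

def f32_to_f6_e3m2_unpacked_sketch(x: Tensor) -> Tensor:
    """Hypothetical reference: one FP6 E3M2 code per uint8 (bits 5..0), no bit packing.
    Out-of-range magnitudes are clamped to 28.0 (the largest representable value)."""
    assert x.dtype == torch.float32
    EBITS, MBITS, BIAS = 3, 2, 3
    sign = torch.signbit(x).to(torch.uint8) << (EBITS + MBITS)
    mag = x.abs().clamp(max=28.0)

    # Snap the magnitude to the FP6 grid with round-to-nearest-even.
    # ULP is 2**(exp - 2) for normals and 2**-4 in the subnormal range (exp clamped to -2).
    exp = torch.floor(torch.log2(mag.clamp(min=2**-6))).clamp(min=-2)
    ulp = torch.exp2(exp - MBITS)
    q = torch.round(mag / ulp) * ulp  # quantized magnitude, still float32

    # Re-derive exponent/mantissa bits from the quantized magnitude
    # (rounding may have pushed the value up into the next binade).
    exp = torch.floor(torch.log2(q.clamp(min=2**-6))).clamp(min=-2)
    e_bits = torch.where(q < 2**-2, torch.zeros_like(exp), exp + BIAS).to(torch.uint8)
    m_bits = torch.round(q / torch.exp2(exp - MBITS)).to(torch.uint8) & 0b11
    return sign | (e_bits << MBITS) | m_bits
```

The second log2 pass is there because rounding can carry the mantissa into the next binade (e.g. 0.49 rounds up to 0.5); the faster bit-shift approach mentioned above avoids the float math entirely by working directly on the FP16 bit pattern.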

(8192, 8192) FP6 input. Ryzen 5600 and 4070 Ti SUPER.

| device | dtype | op | time (ms) |
|--------|-----------|---------------------------|-----------|
| CPU | FP6->FP32 | original | 372.076 |
| CPU | FP6->FP32 | ours | 127.714 |
| CPU | FP6->FP32 | original (num_threads=4) | 375.183 |
| CPU | FP6->FP32 | ours (num_threads=4) | 44.1857 |
| CUDA | FP6->FP32 | ours | 0.572355 |

NOTE:

  • fp6_weight_dequant() (the original implementation) is slow, probably because the author uses the CUDA intrinsics __float2half() and __half2float() on CPU, where they have to be emulated via bit manipulation (a pure-PyTorch sketch of the dequant logic is shown below).
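
For completeness, a minimal pure-PyTorch sketch of the reverse direction (unpacked FP6 E3M2 code held in a uint8 -> float32), ignoring bit packing and the per-row scale. Again, this is not the PR's C++/CUDA kernel; the function name and the unpacked-uint8 layout are assumptions for illustration.

```python
import torch
from torch import Tensor

def f6_e3m2_unpacked_to_f32_sketch(codes: Tensor) -> Tensor:
    """Hypothetical reference: one FP6 E3M2 code per uint8 (bits 5..0) -> float32."""
    assert codes.dtype == torch.uint8
    s = ((codes >> 5) & 0b1).to(torch.float32)
    e = ((codes >> 2) & 0b111).to(torch.float32)
    m = (codes & 0b11).to(torch.float32)
    sign = 1.0 - 2.0 * s                              # +1 or -1
    normal = torch.exp2(e - 3.0) * (1.0 + m / 4.0)    # e > 0: 2**(e - bias) * 1.m
    subnormal = m / 16.0                              # e == 0: 2**(1 - bias) * (m / 4)
    return sign * torch.where(e > 0, normal, subnormal)
```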

pytorch-bot commented May 16, 2024

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/248

✅ No failures as of commit 78e79ac with merge base a7bc592.

return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)


def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor:
Contributor:

nit: thoughts about naming your dtype float6_e3m2 instead of fp6? This is to be consistent with naming for other PyTorch low precision dtypes such as float8_e4m3|e5m2 from PyTorch core as well as the upcoming MX dtypes, which include float6_e3m2 and float6_e2m3.

Collaborator Author (gau-nernst):

I was thinking the same thing too! Will update the name.

Collaborator Author (gau-nernst):

Where can I read more about MX dtypes? The particular FP6 format used by the FP6-LLM paper does not represent +/-inf and NaN, so I'm not sure if we should signal that in the name somehow too (like float8_e4m3fn)?

Contributor:

You can check out https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf; page 12 describes the supported float6 flavors. I plan to add the MX code in torchao soon.

For the fn suffix, I'm planning to follow the OCP spec naming, which does not include qualifiers for special value handling, and replace fp with float to be consistent with other PyTorch dtype names. I think the fn suffix made sense for float8, where different flavors had different special value handling, but none of these sub-8-bit dtypes support special values.

Collaborator Author (gau-nernst):

Cool! It seems like the FP6 I used here is exactly the same as MX FP6 E3M2 (without the scale: the FP6-LLM authors use one scale per row). Perhaps in the future the MX dtype can replace this.
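
As an illustrative aside (not the FP6-LLM implementation itself), per-row scaling on top of the E3M2 conversion could look roughly like this. The function names, and the choice to map each row's max magnitude to the FP6 maximum of 28.0, are assumptions; it reuses the conversion sketches above.

```python
import torch

FP6_E3M2_MAX = 28.0  # largest representable FP6 E3M2 magnitude

def fp6_quantize_rowwise_sketch(weight: torch.Tensor):
    # One scale per row, chosen so the row's largest magnitude maps to FP6_E3M2_MAX.
    scale = weight.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP6_E3M2_MAX
    codes = f32_to_f6_e3m2_unpacked_sketch(weight.float() / scale)  # sketch from above
    return codes, scale.squeeze(-1).half()

def fp6_dequantize_rowwise_sketch(codes: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Inverse: decode the FP6 codes, then re-apply the per-row scale.
    return f6_e3m2_unpacked_to_f32_sketch(codes) * scale.float().unsqueeze(-1)
```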

@@ -14,6 +8,13 @@
from . import _C
Collaborator Author (gau-nernst):

Need to import _C first since to/from_float6_e3m2() (from dtypes) calls the C++ extension for CPU.
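
Roughly, the constraint described above looks like this (a sketch, not the exact torchao/__init__.py contents; the re-exported names are taken from the comment above and are assumptions):

```python
# torchao/__init__.py (sketch of the import-order constraint)
from . import _C  # load the C++ extension first so its CPU ops get registered

# these call into the ops registered by _C when running on CPU
from .dtypes import to_float6_e3m2, from_float6_e3m2
```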

@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
}
}

void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
Collaborator Author (gau-nernst):

Replaced with from_float6_e3m2()

@msaroufim merged commit 4ca3985 into pytorch:main on May 25, 2024
13 checks passed
@gau-nernst deleted the fp6_quant branch on May 25, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request on Jul 31, 2024