-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve primitives for FP6 quant #248
Merged
Merged
Changes from 86 commits
Commits
Show all changes
90 commits
Select commit
Hold shift + click to select a range
97924d7
add fp16_to_fp6 prototype
gau-nernst 8bf081c
minor rename
gau-nernst 558f4e4
Merge branch 'pytorch:main' into fp6_quant
gau-nernst 314e9f6
fix rounding issue
gau-nernst 030b956
Merge branch 'pytorch:main' into fp6_quant
gau-nernst 79ce0db
update quant
gau-nernst 45a92f3
add unpacked version
gau-nernst a8555e3
remove unnecessary comment
gau-nernst 012176e
add CUDA version
gau-nernst d4b8681
add fp6 packed cpu
gau-nernst f0f3101
add CUDA for packed
gau-nernst f542eb1
some rename
gau-nernst 40dc725
update name
gau-nernst 3a98874
Merge branch 'main' into fp6_quant
gau-nernst eef2f95
add OpenMP
gau-nernst f61aa37
fix CUDA bug
gau-nernst 1640bbf
add fp6->fp16
gau-nernst 7a00b31
add FP6->FP32
gau-nernst b2fcc6c
move files around
gau-nernst d9ca476
rearrange stuff
gau-nernst ba89a0b
add more things
gau-nernst faf8682
Merge branch 'pytorch:main' into fp6_quant
gau-nernst e7b3135
update
gau-nernst 8b3ac04
update. add comments
gau-nernst 0635882
some rename. add some tests
gau-nernst 4240692
add fp32->fp6 unpacked
gau-nernst 7eb6fa8
fix
gau-nernst 1c0e401
Merge branch 'main' into fp6_quant
gau-nernst 26669b6
use template. add BF16
gau-nernst e09b61f
use template
gau-nernst 887bac2
simplify API. add BF16 support via templates
gau-nernst 4b5c99f
typo
gau-nernst 39f9dce
enable OpenMP via compile flags
gau-nernst b681ae1
add memory access optimized version (though it is not faster..)
gau-nernst 7c5fcd3
use fp32 mul impl for CUDA
gau-nernst 82e4e60
add test case
gau-nernst fb18c73
typo. remove OpenMP since we cannot throw exception
gau-nernst a3c5e36
fix rounding for subnormal
gau-nernst 27781e5
add to_fp6_value()
gau-nernst 7d9dd34
simplify to_fp6_unpacked_cuda
gau-nernst 7fb8c8b
simplify to_fp6_packed_cuda
gau-nernst 965838c
clean up CPU impl
gau-nernst a64421e
add FP6->FP16/BF16
gau-nernst a4b7c7a
add dim check
gau-nernst 3e4c1c1
add qtorch to dev req
gau-nernst 632af93
handle exception with OpenMP
gau-nernst 6c6fe83
handle exception in OpenMP
gau-nernst cb08b37
add tests
gau-nernst 9f94030
more tests
gau-nernst 7b7e823
simplify test
gau-nernst 7c1ff7d
rename
gau-nernst 0472b06
add back checks
gau-nernst 0bda927
update docs
gau-nernst a21837c
add pure pytorch impl
gau-nernst 6101869
add benchmark
gau-nernst 81d7aeb
Merge branch 'main' into fp6_quant
msaroufim 4c1da5f
Merge branch 'main' into fp6_quant
gau-nernst df7932b
update benchmark script
gau-nernst bdbd907
add triton kernel
gau-nernst 42bf771
remove CUDA kernel
gau-nernst f178b01
move to_fp6 to dtypes/
gau-nernst ee2310c
add to_fp6 import
gau-nernst 404f700
move tests
gau-nernst 48fe45f
update benchmark script
gau-nernst 5126f8f
add from_fp6
gau-nernst da767fa
migrate test
gau-nernst 0b56ecf
add docs
gau-nernst 110e888
add docs
gau-nernst 3e2643c
add torch.compile test
gau-nernst 71ecc45
Merge branch 'main' into fp6_quant
msaroufim 750fbc6
polish docs
gau-nernst 6a3f0c0
remove original weight dequant
gau-nernst f32d09f
remove weight dequant
gau-nernst 8b5b81e
improve tests
gau-nernst a3cf93b
update names
gau-nernst 3c636ff
rename
gau-nernst f672c70
update names
gau-nernst 1a310e3
add notes about denormal numbers
gau-nernst c9ec255
update note
gau-nernst d1697e7
Merge branch 'main' into fp6_quant
gau-nernst 8c86028
Merge branch 'main' into fp6_quant
gau-nernst d24dba8
fix merge problem
gau-nernst ce5dac1
fix merge conflict
gau-nernst 922446d
add to_fp6 CPU C++ kernel
gau-nernst d287eb3
add from_fp6 cpu C++
gau-nernst ce7e09a
rename
gau-nernst 22007a1
add some comments
gau-nernst f97421a
small cleanup
gau-nernst f727de0
always use uint32_t for bit manipulation
gau-nernst 78e79ac
simplify test
gau-nernst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import ( | ||
TestCase, | ||
instantiate_parametrized_tests, | ||
parametrize, | ||
run_tests, | ||
) | ||
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 | ||
|
||
|
||
_DTYPES = [torch.float32, torch.float16, torch.bfloat16] | ||
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) | ||
|
||
|
||
class TestFp6(TestCase): | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("dtype", _DTYPES) | ||
@parametrize( | ||
"input_output", | ||
[ | ||
(0.0, 0b000000), # exact values | ||
(1.0, 0b001100), # normal numbers | ||
(1.25, 0b001101), | ||
(28.0, 0b011111), # max | ||
(0.1875, 0b000011), # subnormal number | ||
(0.0625, 0b000001), # min | ||
(29.0, 0b011111), # normal round down | ||
(26.0, 0b011110), # normal round to nearest even | ||
(0.1251, 0b000010), # subnormal round down | ||
(0.0314, 0b000001), # subnormal round up | ||
(0.03, 0b000000), # underflow | ||
], | ||
) | ||
def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): | ||
input, output = input_output | ||
input = torch.tensor(input, device=device, dtype=dtype) | ||
assert to_float6_e3m2(input, no_bit_packing=True).item() == output | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("dtype", _DTYPES) | ||
def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): | ||
x = torch.randn(128, 128, device=device, dtype=dtype) | ||
results_unpacked = to_float6_e3m2(x, no_bit_packing=True) | ||
results_packed = to_float6_e3m2(x) | ||
|
||
val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) | ||
bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 | ||
bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 | ||
bits2 = (val2 << 6) | (val3); # 2233 3333 | ||
|
||
expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) | ||
assert (results_packed == expected_packed).all() | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("shape", [(), (0,), (10,), (20, 20)]) | ||
def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): | ||
x = torch.randn(shape, device=device) | ||
result = to_float6_e3m2(x, no_bit_packing=True) | ||
assert result.shape == shape | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("shape", [(4,), (20, 20)]) | ||
def test_to_float6_e3m2_bit_packing_shape(self, device, shape): | ||
x = torch.randn(shape, device=device) | ||
result = to_float6_e3m2(x) | ||
assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("dtype", _DTYPES) | ||
@parametrize("no_bit_packing", [False, True]) | ||
def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): | ||
x = torch.randn(20, 20, device=device, dtype=dtype) | ||
expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) | ||
|
||
to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) | ||
actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize( | ||
"input_output", | ||
[ | ||
(0b000000, 0.0), | ||
(0b001100, 1.0), | ||
(0b011111, 28.0), | ||
(0b000001, 0.0625), | ||
(0b001110, 1.5), | ||
(0b000011, 0.1875), | ||
], | ||
) | ||
def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): | ||
input, output = input_output | ||
input = torch.tensor(input, device=device, dtype=torch.uint8) | ||
assert from_float6_e3m2(input, no_bit_packing=True).item() == output | ||
|
||
@parametrize("device", _DEVICES) | ||
def test_from_float6_e3m2_bit_packing_correctness(self, device): | ||
x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) | ||
actual = from_float6_e3m2(x) | ||
|
||
bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) | ||
x_unpacked0 = bits0 >> 2 | ||
x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) | ||
x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) | ||
x_unpacked3 = bits2 & 0x3F | ||
|
||
x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) | ||
expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
@parametrize("device", _DEVICES) | ||
@parametrize("no_bit_packing", [False, True]) | ||
def test_from_float6_e3m2_compile(self, device, no_bit_packing): | ||
x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) | ||
expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) | ||
|
||
from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) | ||
actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) | ||
torch.testing.assert_close(actual, expected) | ||
|
||
|
||
instantiate_parametrized_tests(TestFp6) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,6 @@ | |
// limitations under the License. | ||
// | ||
// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h | ||
// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h | ||
|
||
#include <cuda_fp16.h> | ||
#include <iostream> | ||
|
@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, | |
} | ||
} | ||
|
||
void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replaced with |
||
assert(M%64==0); // Currently, M must be a multiple of 64. | ||
assert(K%64==0); // Currently, K must be a multiple of 64. | ||
size_t TotalSizeInByte = M*K*6/8; | ||
// | ||
half* OutPTR = A_16bit_h; | ||
for(size_t i=0; i<TotalSizeInByte/3; i++) { // 4 FP6 = 3 Bytes for each Loop | ||
unsigned char B1 = A_6bit_h[i*3+0] & 0xfc; | ||
B1 = (B1&0x80) | ((B1>>2)&0x1f); | ||
unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); | ||
B2 = (B2&0x80) | ((B2>>2)&0x1f); | ||
unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); | ||
B3 = (B3&0x80) | ((B3>>2)&0x1f); | ||
unsigned char B4 = A_6bit_h[i*3+2]<<2; | ||
B4 = (B4&0x80) | ((B4>>2)&0x1f); | ||
half FP1, FP2, FP3, FP4; | ||
unsigned char *PTR1, *PTR2, *PTR3, *PTR4; | ||
PTR1 = reinterpret_cast<unsigned char*>(&FP1); | ||
PTR2 = reinterpret_cast<unsigned char*>(&FP2); | ||
PTR3 = reinterpret_cast<unsigned char*>(&FP3); | ||
PTR4 = reinterpret_cast<unsigned char*>(&FP4); | ||
PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU | ||
PTR2[0] = 0; PTR2[1] = B2; | ||
PTR3[0] = 0; PTR3[1] = B3; | ||
PTR4[0] = 0; PTR4[1] = B4; | ||
OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); | ||
OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); | ||
OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); | ||
OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); | ||
// | ||
OutPTR +=4; | ||
} | ||
} | ||
|
||
|
||
#include <torch/extension.h> | ||
#include <ATen/ATen.h> | ||
#include <torch/library.h> | ||
|
||
namespace torchao { | ||
|
||
// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 | ||
at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) | ||
at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor) | ||
{ | ||
TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); | ||
TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); | ||
|
@@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) | |
return packed_fp6_tensor; | ||
} | ||
|
||
/* | ||
* Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. | ||
* A useful tool to construct input matrices for the FP16 GEMM baseline. | ||
* [Input] | ||
* fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. | ||
* fp16_scale: half tensor of shape [OC]; // for row-wise quantization. | ||
* [Output] | ||
* fp16_tensor: half tensor of shape [OC, IC]. | ||
*/ | ||
at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) | ||
{ | ||
int OC = fp6_tensor.size(0); | ||
TORCH_CHECK(fp6_tensor.size(1) % 3 == 0); | ||
int IC = fp6_tensor.size(1) / 3 * 16; | ||
TORCH_CHECK(fp16_scale.size(0) == OC); | ||
// | ||
auto fp6_tensor_ptr = reinterpret_cast<unsigned char*>(fp6_tensor.data_ptr<int>()); | ||
auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>()); | ||
// | ||
auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device()); | ||
at::Tensor fp16_tensor = at::empty({OC, IC}, options); | ||
auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>()); | ||
// | ||
DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr); | ||
// | ||
return fp16_tensor; | ||
} | ||
|
||
TORCH_LIBRARY_IMPL(torchao, CPU, m) { | ||
m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); | ||
m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); | ||
m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); | ||
} | ||
|
||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to import
_C
first sinceto/from_float6_e3m2()
(fromdtypes
) calls C++ extension for CPU.