Improve primitives for FP6 quant #248

Merged: 90 commits merged into main from fp6_quant on May 25, 2024
Changes from 86 commits
Commits
90 commits
97924d7
add fp16_to_fp6 prototype
gau-nernst May 15, 2024
8bf081c
minor rename
gau-nernst May 15, 2024
558f4e4
Merge branch 'pytorch:main' into fp6_quant
gau-nernst May 16, 2024
314e9f6
fix rounding issue
gau-nernst May 16, 2024
030b956
Merge branch 'pytorch:main' into fp6_quant
gau-nernst May 16, 2024
79ce0db
update quant
gau-nernst May 16, 2024
45a92f3
add unpacked version
gau-nernst May 16, 2024
a8555e3
remove unnecessary comment
gau-nernst May 16, 2024
012176e
add CUDA version
gau-nernst May 16, 2024
d4b8681
add fp6 packed cpu
gau-nernst May 16, 2024
f0f3101
add CUDA for packed
gau-nernst May 16, 2024
f542eb1
some rename
gau-nernst May 16, 2024
40dc725
update name
gau-nernst May 16, 2024
3a98874
Merge branch 'main' into fp6_quant
gau-nernst May 17, 2024
eef2f95
add OpenMP
gau-nernst May 17, 2024
f61aa37
fix CUDA bug
gau-nernst May 17, 2024
1640bbf
add fp6->fp16
gau-nernst May 17, 2024
7a00b31
add FP6->FP32
gau-nernst May 17, 2024
b2fcc6c
move files around
gau-nernst May 17, 2024
d9ca476
rearrange stuff
gau-nernst May 17, 2024
ba89a0b
add more things
gau-nernst May 18, 2024
faf8682
Merge branch 'pytorch:main' into fp6_quant
gau-nernst May 18, 2024
e7b3135
update
gau-nernst May 18, 2024
8b3ac04
update. add comments
gau-nernst May 18, 2024
0635882
some rename. add some tests
gau-nernst May 18, 2024
4240692
add fp32->fp6 unpacked
gau-nernst May 18, 2024
7eb6fa8
fix
gau-nernst May 18, 2024
1c0e401
Merge branch 'main' into fp6_quant
gau-nernst May 18, 2024
26669b6
use template. add BF16
gau-nernst May 19, 2024
e09b61f
use template
gau-nernst May 19, 2024
887bac2
simplify API. add BF16 support via templates
gau-nernst May 19, 2024
4b5c99f
typo
gau-nernst May 19, 2024
39f9dce
enable OpenMP via compile flags
gau-nernst May 19, 2024
b681ae1
add memory access optimized version (though it is not faster..)
gau-nernst May 19, 2024
7c5fcd3
use fp32 mul impl for CUDA
gau-nernst May 19, 2024
82e4e60
add test case
gau-nernst May 19, 2024
fb18c73
typo. remove OpenMP since we cannot throw exception
gau-nernst May 19, 2024
a3c5e36
fix rounding for subnormal
gau-nernst May 19, 2024
27781e5
add to_fp6_value()
gau-nernst May 20, 2024
7d9dd34
simplify to_fp6_unpacked_cuda
gau-nernst May 20, 2024
7fb8c8b
simplify to_fp6_packed_cuda
gau-nernst May 20, 2024
965838c
clean up CPU impl
gau-nernst May 20, 2024
a64421e
add FP6->FP16/BF16
gau-nernst May 20, 2024
a4b7c7a
add dim check
gau-nernst May 20, 2024
3e4c1c1
add qtorch to dev req
gau-nernst May 20, 2024
632af93
handle exception with OpenMP
gau-nernst May 20, 2024
6c6fe83
handle exception in OpenMP
gau-nernst May 20, 2024
cb08b37
add tests
gau-nernst May 20, 2024
9f94030
more tests
gau-nernst May 20, 2024
7b7e823
simplify test
gau-nernst May 20, 2024
7c1ff7d
rename
gau-nernst May 20, 2024
0472b06
add back checks
gau-nernst May 20, 2024
0bda927
update docs
gau-nernst May 20, 2024
a21837c
add pure pytorch impl
gau-nernst May 20, 2024
6101869
add benchmark
gau-nernst May 20, 2024
81d7aeb
Merge branch 'main' into fp6_quant
msaroufim May 20, 2024
4c1da5f
Merge branch 'main' into fp6_quant
gau-nernst May 21, 2024
df7932b
update benchmark script
gau-nernst May 21, 2024
bdbd907
add triton kernel
gau-nernst May 21, 2024
42bf771
remove CUDA kernel
gau-nernst May 21, 2024
f178b01
move to_fp6 to dtypes/
gau-nernst May 21, 2024
ee2310c
add to_fp6 import
gau-nernst May 21, 2024
404f700
move tests
gau-nernst May 21, 2024
48fe45f
update benchmark script
gau-nernst May 21, 2024
5126f8f
add from_fp6
gau-nernst May 21, 2024
da767fa
migrate test
gau-nernst May 21, 2024
0b56ecf
add docs
gau-nernst May 21, 2024
110e888
add docs
gau-nernst May 21, 2024
3e2643c
add torch.compile test
gau-nernst May 21, 2024
71ecc45
Merge branch 'main' into fp6_quant
msaroufim May 21, 2024
750fbc6
polish docs
gau-nernst May 21, 2024
6a3f0c0
remove original weight dequant
gau-nernst May 21, 2024
f32d09f
remove weight dequant
gau-nernst May 21, 2024
8b5b81e
improve tests
gau-nernst May 22, 2024
a3cf93b
update names
gau-nernst May 22, 2024
3c636ff
rename
gau-nernst May 22, 2024
f672c70
update names
gau-nernst May 22, 2024
1a310e3
add notes about denormal numbers
gau-nernst May 23, 2024
c9ec255
update note
gau-nernst May 23, 2024
d1697e7
Merge branch 'main' into fp6_quant
gau-nernst May 25, 2024
8c86028
Merge branch 'main' into fp6_quant
gau-nernst May 25, 2024
d24dba8
fix merge problem
gau-nernst May 25, 2024
ce5dac1
fix merge conflict
gau-nernst May 25, 2024
922446d
add to_fp6 CPU C++ kernel
gau-nernst May 25, 2024
d287eb3
add from_fp6 cpu C++
gau-nernst May 25, 2024
ce7e09a
rename
gau-nernst May 25, 2024
22007a1
add some comments
gau-nernst May 25, 2024
f97421a
small cleanup
gau-nernst May 25, 2024
f727de0
always use uint32_t for bit manipulation
gau-nernst May 25, 2024
78e79ac
simplify test
gau-nernst May 25, 2024
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -12,3 +12,6 @@ pandas

# Custom CUDA Extensions
ninja

# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
qtorch
2 changes: 2 additions & 0 deletions docs/source/api_ref_dtypes.rst
@@ -12,6 +12,8 @@ torchao.dtypes

to_nf4
UInt4Tensor
to_float6_e3m2
from_float6_e3m2

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
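The two entries above document the new public conversion helpers. A usage sketch inferred from the tests this PR adds in test/dtypes/test_float6_e3m2.py; the packed/unpacked shapes follow those tests, while the exact output dtype of from_float6_e3m2 is an assumption (the matmul test further below converts it with .half(), suggesting float32):

import torch
from torchao.dtypes import to_float6_e3m2, from_float6_e3m2

x = torch.randn(4, 8)                               # last dim divisible by 4, as required for bit packing

packed = to_float6_e3m2(x)                          # uint8 codes, shape (4, 8 // 4 * 3) == (4, 6)
unpacked = to_float6_e3m2(x, no_bit_packing=True)   # uint8, one FP6 code per byte, shape (4, 8)

y_packed = from_float6_e3m2(packed)                            # floating point, shape (4, 8)
y_unpacked = from_float6_e3m2(unpacked, no_bit_packing=True)   # same values, decoded per byte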
3 changes: 2 additions & 1 deletion setup.py
@@ -46,11 +46,12 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_link_args = ["-fopenmp"]
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
127 changes: 127 additions & 0 deletions test/dtypes/test_float6_e3m2.py
@@ -0,0 +1,127 @@
import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6(TestCase):

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@parametrize(
"input_output",
[
(0.0, 0b000000), # exact values
(1.0, 0b001100), # normal numbers
(1.25, 0b001101),
(28.0, 0b011111), # max
(0.1875, 0b000011), # subnormal number
(0.0625, 0b000001), # min
(29.0, 0b011111), # normal round down
(26.0, 0b011110), # normal round to nearest even
(0.1251, 0b000010), # subnormal round down
(0.0314, 0b000001), # subnormal round up
(0.03, 0b000000), # underflow
],
)
def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output):
input, output = input_output
input = torch.tensor(input, device=device, dtype=dtype)
assert to_float6_e3m2(input, no_bit_packing=True).item() == output

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype):
x = torch.randn(128, 128, device=device, dtype=dtype)
results_unpacked = to_float6_e3m2(x, no_bit_packing=True)
results_packed = to_float6_e3m2(x)

val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1)
bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011
bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222
bits2 = (val2 << 6) | (val3); # 2233 3333

expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)
assert (results_packed == expected_packed).all()

@parametrize("device", _DEVICES)
@parametrize("shape", [(), (0,), (10,), (20, 20)])
def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape):
x = torch.randn(shape, device=device)
result = to_float6_e3m2(x, no_bit_packing=True)
assert result.shape == shape

@parametrize("device", _DEVICES)
@parametrize("shape", [(4,), (20, 20)])
def test_to_float6_e3m2_bit_packing_shape(self, device, shape):
x = torch.randn(shape, device=device)
result = to_float6_e3m2(x)
assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,)

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@parametrize("no_bit_packing", [False, True])
def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing):
x = torch.randn(20, 20, device=device, dtype=dtype)
expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing)

to_float6_e3m2_compiled = torch.compile(to_float6_e3m2)
actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@parametrize(
"input_output",
[
(0b000000, 0.0),
(0b001100, 1.0),
(0b011111, 28.0),
(0b000001, 0.0625),
(0b001110, 1.5),
(0b000011, 0.1875),
],
)
def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output):
input, output = input_output
input = torch.tensor(input, device=device, dtype=torch.uint8)
assert from_float6_e3m2(input, no_bit_packing=True).item() == output

@parametrize("device", _DEVICES)
def test_from_float6_e3m2_bit_packing_correctness(self, device):
x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8)
actual = from_float6_e3m2(x)

bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1)
x_unpacked0 = bits0 >> 2
x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4)
x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6)
x_unpacked3 = bits2 & 0x3F

x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2)
expected = from_float6_e3m2(x_unpacked, no_bit_packing=True)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@parametrize("no_bit_packing", [False, True])
def test_from_float6_e3m2_compile(self, device, no_bit_packing):
x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8)
expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing)

from_float6_e3m2_compiled = torch.compile(from_float6_e3m2)
actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFp6)


if __name__ == "__main__":
run_tests()
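For readers new to the format: a minimal pure-Python reference for the e3m2 layout these tests pin down (1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits, subnormals, no inf/nan) and for the 4-codes-into-3-bytes packing checked above. The helper names are illustrative only, not part of torchao.

def float6_e3m2_bits_to_float(bits: int) -> float:
    # Decode one 6-bit e3m2 code held in the low 6 bits of a Python int.
    sign = -1.0 if (bits >> 5) & 0x1 else 1.0
    exp = (bits >> 2) & 0x7
    mantissa = bits & 0x3
    if exp == 0:                                    # subnormal: 2**(1 - bias) * mantissa / 4
        return sign * 2.0 ** -2 * (mantissa / 4.0)
    return sign * 2.0 ** (exp - 3) * (1.0 + mantissa / 4.0)

# Spot-check against the parametrized vectors above.
assert float6_e3m2_bits_to_float(0b001100) == 1.0
assert float6_e3m2_bits_to_float(0b001101) == 1.25
assert float6_e3m2_bits_to_float(0b011111) == 28.0     # largest magnitude
assert float6_e3m2_bits_to_float(0b000001) == 0.0625   # smallest nonzero (subnormal)
assert float6_e3m2_bits_to_float(0b000011) == 0.1875

def pack_four_fp6_codes(v0: int, v1: int, v2: int, v3: int):
    # Pack four 6-bit codes into three bytes (0000 0011 | 1111 2222 | 2233 3333),
    # matching test_to_float6_e3m2_bit_packing_correctness.
    byte0 = ((v0 << 2) | (v1 >> 4)) & 0xFF
    byte1 = ((v1 << 4) | (v2 >> 2)) & 0xFF
    byte2 = ((v2 << 6) | v3) & 0xFF
    return byte0, byte1, byte2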
33 changes: 9 additions & 24 deletions test/test_ops.py
@@ -50,24 +50,21 @@ def test_prepack_fp6_weight(self):
opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16_to_fp6(self):
def test_fp16_to_fp6_original(self):
OC = 256
IC = 256

# in this fp6, we use 3 bits for exponent and 2 bits for mantissa
# also, we don't have nan/inf
fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
fp16_weight.clip_(-fp6_absmax, fp6_absmax)
fp16_weight[fp16_weight.abs() < fp6_absmin] = 0

# the original FP16->FP6 kernel checks for overflow/underflow
fp16_weight.clip_(-28.0, 28.0)
fp16_weight[fp16_weight.abs() < 0.0625] = 0.0

# smoke test
torchao.ops.fp16_to_fp6(fp16_weight)
torchao.ops.fp16_to_fp6_original(fp16_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)
opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
@@ -89,19 +86,6 @@ def test_fp16act_fp6weight_linear(self):
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_weight_dequant(self):
OC = 256
IC = 256
fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)

# smoke test
torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -115,7 +99,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):

results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
fp32_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8)) * fp16_scale[:, None]
fp16_weight = fp32_weight.half().cuda()
results_fp16 = act_cuda @ fp16_weight.T

error = (results_fp6 - results_fp16).abs()
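For reference, the 28.0 and 0.0625 clip constants in test_fp16_to_fp6_original above come straight from the e3m2 range; this short derivation restates the comments removed in this file:

# Largest normal value: E = 0b111, M = 0b11  ->  2**(0b111 - 0b011) * (1 + 0.5 + 0.25)
fp6_absmax = 2.0 ** (0b111 - 0b011) * 1.75   # 28.0
# Smallest nonzero (subnormal) value: E = 0b000, M = 0b01  ->  2**(-0b010) * 0.25
fp6_absmin = 2.0 ** -0b010 * 0.25            # 0.0625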
13 changes: 7 additions & 6 deletions torchao/__init__.py
@@ -1,9 +1,3 @@
from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes
import torch
_IS_FBCODE = (
hasattr(torch._utils_internal, "IS_FBSOURCE") and
@@ -14,6 +8,13 @@
from . import _C
from . import ops
Review comment from gau-nernst (Collaborator, Author): Need to import _C first, since to/from_float6_e3m2() (from dtypes) calls the C++ extension for CPU.


from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes

__all__ = [
"dtypes",
"apply_dynamic_quant",
69 changes: 2 additions & 67 deletions torchao/csrc/cuda/fp6_llm/weight_quant.cu
@@ -13,7 +13,6 @@
// limitations under the License.
//
// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h
// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h

#include <cuda_fp16.h>
#include <iostream>
@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
}
}

void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
Review comment from gau-nernst (Collaborator, Author): Replaced with from_float6_e3m2().

assert(M%64==0); // Currently, M must be a multiple of 64.
assert(K%64==0); // Currently, K must be a multiple of 64.
size_t TotalSizeInByte = M*K*6/8;
//
half* OutPTR = A_16bit_h;
for(size_t i=0; i<TotalSizeInByte/3; i++) { // 4 FP6 = 3 Bytes for each Loop
unsigned char B1 = A_6bit_h[i*3+0] & 0xfc;
B1 = (B1&0x80) | ((B1>>2)&0x1f);
unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc);
B2 = (B2&0x80) | ((B2>>2)&0x1f);
unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc);
B3 = (B3&0x80) | ((B3>>2)&0x1f);
unsigned char B4 = A_6bit_h[i*3+2]<<2;
B4 = (B4&0x80) | ((B4>>2)&0x1f);
half FP1, FP2, FP3, FP4;
unsigned char *PTR1, *PTR2, *PTR3, *PTR4;
PTR1 = reinterpret_cast<unsigned char*>(&FP1);
PTR2 = reinterpret_cast<unsigned char*>(&FP2);
PTR3 = reinterpret_cast<unsigned char*>(&FP3);
PTR4 = reinterpret_cast<unsigned char*>(&FP4);
PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU
PTR2[0] = 0; PTR2[1] = B2;
PTR3[0] = 0; PTR3[1] = B3;
PTR4[0] = 0; PTR4[1] = B4;
OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) );
//
OutPTR +=4;
}
}


#include <torch/extension.h>
#include <ATen/ATen.h>
#include <torch/library.h>

namespace torchao {

// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194
at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor)
{
TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional");
TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16");
@@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
return packed_fp6_tensor;
}

/*
* Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.
* A useful tool to construct input matrices for the FP16 GEMM baseline.
* [Input]
* fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights.
* fp16_scale: half tensor of shape [OC]; // for row-wise quantization.
* [Output]
* fp16_tensor: half tensor of shape [OC, IC].
*/
at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale)
{
int OC = fp6_tensor.size(0);
TORCH_CHECK(fp6_tensor.size(1) % 3 == 0);
int IC = fp6_tensor.size(1) / 3 * 16;
TORCH_CHECK(fp16_scale.size(0) == OC);
//
auto fp6_tensor_ptr = reinterpret_cast<unsigned char*>(fp6_tensor.data_ptr<int>());
auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());
//
auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device());
at::Tensor fp16_tensor = at::empty({OC, IC}, options);
auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());
//
DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr);
//
return fp16_tensor;
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu);
m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu);
m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu);
}

}
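For context on the removed DeQuantMatrix_FP6_To_FP16 above: it moves each FP6 code into the high byte of an FP16 pattern (sign in bit 15, the 3-bit exponent and 2-bit mantissa in bits 12..8) and multiplies by 4096.0f, which bridges the exponent-bias gap between FP16 (bias 15) and e3m2 (bias 3), since 2**(15 - 3) = 4096. A rough single-value Python sketch of that trick, illustrative only (the real kernel additionally applies the per-row scale):

import struct

def fp6_to_fp16_bit_trick(fp6: int) -> float:
    # Place the e3m2 fields into the top of an FP16 bit pattern, then rescale.
    sign = (fp6 >> 5) & 0x1
    exp_and_mantissa = fp6 & 0x1F             # e2 e1 e0 m1 m0
    hi_byte = (sign << 7) | exp_and_mantissa  # s 0 0 e2 e1 e0 m1 m0
    fp16_bits = hi_byte << 8                  # low mantissa byte stays zero
    (value,) = struct.unpack("<e", struct.pack("<H", fp16_bits))
    return value * 4096.0                     # 2**(15 - 3): FP16 bias minus e3m2 bias

assert fp6_to_fp16_bit_trick(0b001100) == 1.0
assert fp6_to_fp16_bit_trick(0b011111) == 28.0
assert fp6_to_fp16_bit_trick(0b000001) == 0.0625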