From 7fabc8fdd709fbac41829fb8bdecc1a3c15e86de Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 31 May 2024 07:27:14 +0800 Subject: [PATCH 01/19] override load from state dict --- torchao/quantization/fp6_llm.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 0fb0f7dd9..b9b87559f 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -287,6 +287,22 @@ def from_float(cls, linear: nn.Linear): bias = linear.bias.detach().half() if linear.bias is not None else None return cls(tc_fp6_weight, scales.half(), bias) + # without load_state_dict_pre_hook() https://github.com/pytorch/pytorch/issues/75287 + # we have to override this internal method to be able to convert weights to FP6 on the fly. + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + if state_dict[f"{prefix}.weight"].shape == (self.out_features, self.in_features): + fp32_weight = state_dict[f"{prefix}.weight"].detach().float() + scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX + scales[scales == 0.0] = 1.0 # avoid 0 scale + + tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1)) + tc_fp6_weight = tc_fp6_weight.view(self.out_features, -1).view(torch.int32) + + state_dict[f"{prefix}.weight"] = tc_fp6_weight + state_dict[f"{prefix}.scales"] = scales + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + def extra_repr(self) -> str: return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' From 1c085681727117170aa49b038a13e208adee8acd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 31 May 2024 04:57:26 +0000 Subject: [PATCH 02/19] fix prefix --- torchao/quantization/fp6_llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index b9b87559f..56f74287d 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -290,16 +290,16 @@ def from_float(cls, linear: nn.Linear): # without load_state_dict_pre_hook() https://github.com/pytorch/pytorch/issues/75287 # we have to override this internal method to be able to convert weights to FP6 on the fly. 
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if state_dict[f"{prefix}.weight"].shape == (self.out_features, self.in_features): - fp32_weight = state_dict[f"{prefix}.weight"].detach().float() + if state_dict[f"{prefix}weight"].shape == (self.out_features, self.in_features): + fp32_weight = state_dict[f"{prefix}weight"].detach().float() scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX scales[scales == 0.0] = 1.0 # avoid 0 scale tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1)) tc_fp6_weight = tc_fp6_weight.view(self.out_features, -1).view(torch.int32) - state_dict[f"{prefix}.weight"] = tc_fp6_weight - state_dict[f"{prefix}.scales"] = scales + state_dict[f"{prefix}weight"] = tc_fp6_weight + state_dict[f"{prefix}scales"] = scales return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) From d89e9da9f455b841565a1771736192893340478b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 1 Jun 2024 23:10:41 +0800 Subject: [PATCH 03/19] migrate to mx primitive --- torchao/quantization/fp6_llm.py | 39 +++++++++++++++------------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 56f74287d..4eb20f86f 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -3,7 +3,8 @@ import torch from torch import nn, Tensor -from torchao.dtypes.float6_e3m2 import FLOAT6_E3M2_MAX, to_float6_e3m2, from_float6_e3m2 +from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32 +from torchao.prototype.mx_formats.constants import F6_E3M2_MAX from torchao.ops import fp16act_fp6weight_linear @@ -30,7 +31,7 @@ def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) # Pass 1 from original code tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) @@ -75,7 +76,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = to_float6_e3m2(tensor, no_bit_packing=True) + tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor) tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.flip(3) @@ -90,6 +91,13 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: return torch.cat([tensor_2bit, tensor_4bit], dim=0) +def to_scaled_tc_float6_e3m2(tensor: Tensor) -> tuple[Tensor, Tensor]: + scale = F6_E3M2_MAX / tensor.abs().amax(1).clamp(min=1e-12) + tc_fp6_tensor = to_tc_float6_e3m2(tensor * scale.view(-1, 1)) + tc_fp6_tensor = tc_fp6_tensor.view(tensor.shape[0], -1).view(torch.int32) + return tc_fp6_tensor, scale.reciprocal().half() + + def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: assert tensor.ndim == 1 assert (M % 64 == 0) and (N % 64 == 0) @@ -109,7 +117,7 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor tensor_fp6 = (tensor_2bit << 4) | tensor_4bit tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype) + return f6_e3m2_unpacked_to_f32(tensor_fp6).to(dtype) # 
https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -271,35 +279,24 @@ def forward(self, x: Tensor) -> Tensor: @staticmethod def get_split_k(bsize: int, out_dim: int) -> int: # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 @classmethod def from_float(cls, linear: nn.Linear): assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) - fp32_weight = linear.weight.detach().float() - scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX - scales[scales == 0.0] = 1.0 # avoid 0 scale - - tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1)) - tc_fp6_weight = tc_fp6_weight.view(linear.out_features, -1).view(torch.int32) - + fp6_weight, scale = to_scaled_tc_float6_e3m2(linear.weight.detach()) bias = linear.bias.detach().half() if linear.bias is not None else None - return cls(tc_fp6_weight, scales.half(), bias) + return cls(fp6_weight, scale, bias) # without load_state_dict_pre_hook() https://github.com/pytorch/pytorch/issues/75287 # we have to override this internal method to be able to convert weights to FP6 on the fly. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if state_dict[f"{prefix}weight"].shape == (self.out_features, self.in_features): - fp32_weight = state_dict[f"{prefix}weight"].detach().float() - scales = fp32_weight.abs().amax(1) / FLOAT6_E3M2_MAX - scales[scales == 0.0] = 1.0 # avoid 0 scale - - tc_fp6_weight = to_tc_float6_e3m2(fp32_weight / scales.view(-1, 1)) - tc_fp6_weight = tc_fp6_weight.view(self.out_features, -1).view(torch.int32) + fp6_weight, scale = to_scaled_tc_float6_e3m2(state_dict[f"{prefix}weight"]) - state_dict[f"{prefix}weight"] = tc_fp6_weight - state_dict[f"{prefix}scales"] = scales + state_dict[f"{prefix}weight"] = fp6_weight + state_dict[f"{prefix}scales"] = scale return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) From 6f84293d0c39f0797b4b83980625d5e5c0308c34 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 1 Jun 2024 23:18:21 +0800 Subject: [PATCH 04/19] remove unneeded code --- test/prototype/mx_formats/test_custom_cast.py | 2 + torchao/csrc/{ => cuda}/fp6_llm/README.md | 0 torchao/csrc/cuda/fp6_llm/weight_quant.cu | 154 --------- torchao/csrc/fp6_llm.cpp | 8 + torchao/csrc/fp6_llm/float6_e3m2.cpp | 319 ------------------ torchao/csrc/fp6_llm/fp6_llm.cpp | 15 - torchao/csrc/fp6_llm/weight_prepacking.cpp | 220 ------------ torchao/dtypes/__init__.py | 7 - torchao/dtypes/float6_e3m2.py | 178 ---------- torchao/ops.py | 58 ---- 10 files changed, 10 insertions(+), 951 deletions(-) rename torchao/csrc/{ => cuda}/fp6_llm/README.md (100%) delete mode 100644 torchao/csrc/cuda/fp6_llm/weight_quant.cu create mode 100644 torchao/csrc/fp6_llm.cpp delete mode 100644 torchao/csrc/fp6_llm/float6_e3m2.cpp delete mode 100644 torchao/csrc/fp6_llm/fp6_llm.cpp delete mode 100644 torchao/csrc/fp6_llm/weight_prepacking.cpp delete mode 100644 torchao/dtypes/float6_e3m2.py diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 892d5b57f..85d958633 100644 --- 
a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -386,3 +386,5 @@ def test_fp6_values(dtype_name): else: raise AssertionError("unsupported") torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0) + +# TODO: move test/dtypes/test_float6_e3m2.py here diff --git a/torchao/csrc/fp6_llm/README.md b/torchao/csrc/cuda/fp6_llm/README.md similarity index 100% rename from torchao/csrc/fp6_llm/README.md rename to torchao/csrc/cuda/fp6_llm/README.md diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu deleted file mode 100644 index b519cbfb0..000000000 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h - -#include -#include -#include - -/* - * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. - */ -void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) -{ - // Constants for FP6 - constexpr int exponent_nbits_fp6 = 3; - constexpr int mantissa_nbits_fp6 = 2; - constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; - // Constants for FP16 - constexpr int exponent_nbits_fp16 = 5; - constexpr int mantissa_nbits_fp16 = 10; - constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; - - int fp6_temp[4]; - - float absmin_nonzero_fp6 = 0.0625; - // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is - // the same with that in qtorch. - float absmax_fp6 = 28; - - for (int i = 0; i < 4; ++i) { - uint16_t source = FP16x4[i]; - float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); - if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || - fp6_value_abs > absmax_fp6) { - // TODO(zhen): a better way may be rounding it to the nearest FP6 value. - throw std::invalid_argument("Input value out of range for FP6."); - } - - // It is not safe to do shift operation on uint16_t. So we promote it to int. - int source_promote = int(source); - - int sign_bit = (source_promote >> 15); - // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' - int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; - // Extracting mantissa represented in FP16 - int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); - - int new_exp_bit; - int new_mant_bit; - - if (exp_bit == 0) { - // Subnormal FP16 number. Too small for FP6. - new_exp_bit = 0; - new_mant_bit = 0; - } else { - new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); - new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; - - // Deal with subnormal FP6 values. 
- int target_exp_val = exp_bit - exp_bias_fp16; - int min_fp6_exp_val = -exp_bias_fp6 + 1; - bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; - if (subnormal_fp6) { - // TODO(zhen): add the rounding logic. - new_exp_bit = 0; - // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we - // need to add it - new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> - (min_fp6_exp_val - target_exp_val); - } - } - - fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | - (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; - } - // Pack the values - FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); - FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); - FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; -} - -/* - * Function to prepack FP16 weights into continuous FP6 values. - * - * Parameters: - * weight_16bit: input weight in FP16, size M*K - * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 - * M, K: the shape of the weight - */ -void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, - uint8_t* weight_6bit_packed, - size_t M, - size_t K) -{ - // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). - if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } - size_t K_fp6_packed = K * 6 / 8; - // #pragma omp parallel for - for (auto m = 0; m < M; m++) { - uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; - uint16_t* ptr_16bit = weight_16bit + m * K; - for (auto k = 0; k < K; k += 4) { - cast_fp16_fp6(ptr_16bit, ptr_6bit); - ptr_16bit += 4; - ptr_6bit += 3; - } - } -} - -#include -#include -#include - -namespace torchao { - -// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 -at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor) -{ - TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); - TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); - TORCH_CHECK(fp16_tensor.is_contiguous(), "weight must be contiguous"); - TORCH_CHECK(fp16_tensor.device().type() == torch::kCPU, "weight must be on CPU"); - auto M = fp16_tensor.size(0); - auto K = fp16_tensor.size(1); - TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); - - // Pack weight from FP16 to FP6. 
- auto options = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto packed_fp6_tensor = at::empty({M, K * 6 / 8}, options); - uint8_t* packed_fp6_ptr = packed_fp6_tensor.data_ptr(); - - uint16_t* fake_fp6_ptr = reinterpret_cast(fp16_tensor.data_ptr()); - weight_prepacking_fp16_to_fp6(fake_fp6_ptr, packed_fp6_ptr, M, K); - - return packed_fp6_tensor; -} - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); -} - -} diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp new file mode 100644 index 000000000..997265546 --- /dev/null +++ b/torchao/csrc/fp6_llm.cpp @@ -0,0 +1,8 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); +} diff --git a/torchao/csrc/fp6_llm/float6_e3m2.cpp b/torchao/csrc/fp6_llm/float6_e3m2.cpp deleted file mode 100644 index 16d71f51d..000000000 --- a/torchao/csrc/fp6_llm/float6_e3m2.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include -#include -#include - -#include -#include -#include - - -class float6_e3m2_nan_inf : public std::invalid_argument { -public: - float6_e3m2_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in float6_e3m2.") { } -}; - -class float6_e3m2_overflow : public std::invalid_argument { -public: - float6_e3m2_overflow() : std::invalid_argument("float6_e3m2 overflow. float6_e3m2 cannot represent +/-inf. Make sure input < 30.0") { } -}; - -// we need to do this because C++17 does not allow using struct as template non-type parameter -// use the upper 16 bits for num exponent, lower 16 bits for num mantissa -static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) { return (n_exp << 16u) | n_man; } -static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); -static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); -static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); - -// NOTE: only works for len < 32 -static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } - -// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -template -static uint8_t to_float6_e3m2_bits(T bits_) { - constexpr uint32_t N_EXP = FP_SPEC >> 16u; - constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); - constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; - constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; - - // sanity checks. will be removed in template instantiation. - // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. 
- static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); - static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); - - uint32_t bits = bits_; // bit extension - uint32_t sign = bits >> N_EXP_MAN << 5u; - bits &= ones_mask(N_EXP_MAN); // clear sign bit - uint32_t result, remainder; - - // all exponent bits are 1s - if (bits >= (ones_mask(N_EXP) << N_MAN)) throw float6_e3m2_nan_inf(); - - // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw float6_e3m2_overflow(); - - // FP6 normal number (E>=001) - if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { - remainder = bits << (32u - (N_MAN - 2u)); // shift the truncated bits to most significant position - bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent - result = sign | (bits >> (N_MAN - 2u)); - } - // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) - else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) { - uint32_t exp = bits >> N_MAN; - uint32_t man = bits & ones_mask(N_MAN); - - // to make subnormal FP6 from normal FP16 - // step 1: add implicit 1 to mantissa - man |= (1u << N_MAN); - - // step 2: shift mantissa right so that exponent value is equal to - // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) - uint32_t shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (32u - (N_MAN - 2u + shift)); // shift the truncated bits to most significant position - result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 - } - // FP6 underflow. E=000, M=00 - else { - remainder = 0u; - result = sign; - } - - // round to nearest even - if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 0x1u))) { - result += 1; - } - return result; -} - -// assume the lower 6 bits contain the data. -template -static T from_float6_e3m2_bits(uint8_t a) { - constexpr uint32_t N_EXP = FP_SPEC >> 16u; - constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); - constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; - constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; - - uint32_t bits = a; // bit extension - uint32_t sign = bits >> 5u; - uint32_t exp = (bits >> 2u) & 0x7u; - uint32_t man = bits & 0x3u; - - if (exp > 0u) { // FP6 normal numbers - exp += EXP_BIAS_DIFF; - } else if (man > 0u) { // FP6 denormal numbers - uint32_t shift = (man >= 0b10u) ? 1u : 2u; - man = (man << shift) & 0x3u; // shift and remove explicit 1 - exp = 1u + EXP_BIAS_DIFF - shift; - } - // don't need to handle zero, since E=000 and M=00 - - uint32_t result = (sign << N_EXP_MAN) | (exp << N_MAN) | (man << (N_MAN - 2u)); - return static_cast(result); -} - -namespace torchao { - -template void to_float6_e3m2_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - // exception within OpenMP parallel region must be caught. - // set a flag when exception occurs, then re-raise it. 
- bool found_nan_inf = false; - bool found_overflow = false; - -#pragma omp parallel for - for (int i = 0; i < n; i++) { - try { fp6_ptr[i] = to_float6_e3m2_bits(bits_ptr[i]); } - catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; } - catch (float6_e3m2_overflow const &) { found_overflow = true; } - } - - if (found_nan_inf) throw float6_e3m2_nan_inf(); - if (found_overflow) throw float6_e3m2_overflow(); -} - -// this is useful for debugging -at::Tensor to_float6_e3m2_unpacked_cpu(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cpu()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_float6_e3m2_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_float6_e3m2_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_float6_e3m2_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -template void to_float6_e3m2_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - // exception within OpenMP parallel region must be caught. - // set a flag when exception occurs, then re-raise it. - bool found_nan_inf = false; - bool found_overflow = false; - -#pragma omp parallel for - for (int i = 0; i < n / 4; i++) { - try { - uint8_t val0 = to_float6_e3m2_bits(bits_ptr[i * 4]); - uint8_t val1 = to_float6_e3m2_bits(bits_ptr[i * 4 + 1]); - uint8_t val2 = to_float6_e3m2_bits(bits_ptr[i * 4 + 2]); - uint8_t val3 = to_float6_e3m2_bits(bits_ptr[i * 4 + 3]); - - fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } - catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; } - catch (float6_e3m2_overflow const &) { found_overflow = true; } - } - - if (found_nan_inf) throw float6_e3m2_nan_inf(); - if (found_overflow) throw float6_e3m2_overflow(); -} - -at::Tensor to_float6_e3m2_packed_cpu(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cpu()); - TORCH_CHECK(fp_tensor.ndimension() == 2); - - int M = fp_tensor.size(0); - int N = fp_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_float6_e3m2_packed_cpu_impl(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_float6_e3m2_packed_cpu_impl(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - 
to_float6_e3m2_packed_cpu_impl(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -template -void from_float6_e3m2_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { -#pragma omp parallel for - for (int i = 0; i < n; i++) - fp_ptr[i] = from_float6_e3m2_bits(fp6_ptr[i]); -} - -at::Tensor from_float6_e3m2_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu()); - - at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (dtype == torch::kFloat32) { - uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, fp32_ptr, n); - - } else if (dtype == torch::kFloat16) { - uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, fp16_ptr, n); - - } else if (dtype == torch::kBFloat16) { - uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, bf16_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp_tensor; -} - -template -void from_float6_e3m2_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { -#pragma omp parallel for - for (int i = 0; i < n / 3; i++) { - uint8_t bits0 = fp6_ptr[i * 3]; // 0000 0011 - uint8_t bits1 = fp6_ptr[i * 3 + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[i * 3 + 2]; // 2233 3333 - - fp_ptr[i * 4] = from_float6_e3m2_bits(bits0 >> 2); - fp_ptr[i * 4 + 1] = from_float6_e3m2_bits(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp_ptr[i * 4 + 2] = from_float6_e3m2_bits(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp_ptr[i * 4 + 3] = from_float6_e3m2_bits(bits2 & 0x3Fu); - } -} - -at::Tensor from_float6_e3m2_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu()); - TORCH_CHECK(fp6_tensor.ndimension() == 2); - - int M = fp6_tensor.size(0); - int N = fp6_tensor.size(1); - TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (dtype == torch::kFloat32) { - uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_packed_cpu_impl(fp6_ptr, fp32_ptr, n); - - } else if (dtype == torch::kFloat16) { - uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_packed_cpu_impl(fp6_ptr, fp16_ptr, n); - - } else if (dtype == torch::kBFloat16) { - uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_float6_e3m2_packed_cpu_impl(fp6_ptr, bf16_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp_tensor; -} - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::to_float6_e3m2_unpacked_cpu", &to_float6_e3m2_unpacked_cpu); - m.impl("torchao::to_float6_e3m2_packed_cpu", &to_float6_e3m2_packed_cpu); - 
m.impl("torchao::from_float6_e3m2_unpacked_cpu", &from_float6_e3m2_unpacked_cpu); - m.impl("torchao::from_float6_e3m2_packed_cpu", &from_float6_e3m2_packed_cpu); -} - -} diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp deleted file mode 100644 index 5239593bb..000000000 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include -#include -#include - -TORCH_LIBRARY_FRAGMENT(torchao, m) { - m.impl_abstract_pystub("torchao.ops"); - m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); - m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); - m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); - - m.def("to_float6_e3m2_unpacked_cpu(Tensor tensor) -> Tensor"); - m.def("to_float6_e3m2_packed_cpu(Tensor tensor) -> Tensor"); - m.def("from_float6_e3m2_unpacked_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); - m.def("from_float6_e3m2_packed_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); -} diff --git a/torchao/csrc/fp6_llm/weight_prepacking.cpp b/torchao/csrc/fp6_llm/weight_prepacking.cpp deleted file mode 100644 index 89a1171f5..000000000 --- a/torchao/csrc/fp6_llm/weight_prepacking.cpp +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h - -#include -#include -#include - -using namespace std; - -void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], unsigned char* FP6_Array) // padding 0 to the lowerest bit location -{ - Padded_FP6[0] = FP6_Array[0] & 0xfc; - Padded_FP6[1] = (FP6_Array[0]<<6) | ((FP6_Array[1]>>2) & 0xfc); - Padded_FP6[2] = (FP6_Array[1]<<4) | ((FP6_Array[2]>>4) & 0xfc ); - Padded_FP6[3] = FP6_Array[2]<<2; - Padded_FP6[4] = FP6_Array[3] & 0xfc; - Padded_FP6[5] = (FP6_Array[3]<<6) | ((FP6_Array[4]>>2) & 0xfc); - Padded_FP6[6] = (FP6_Array[4]<<4) | ((FP6_Array[5]>>4) & 0xfc); - Padded_FP6[7] = FP6_Array[5]<<2; -} - -unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, unsigned char B2, unsigned char B3, unsigned char B4) -{ - unsigned char out; - out = (B1&0xc0) | ( (B2&0xc0) >> 2 ) | ( (B3&0xc0) >> 4 ) | ( (B4&0xc0) >> 6 ); - return out; -} - -unsigned char Extract_4_Bits_From_2_PaddedFP6(unsigned char B1, unsigned char B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); -{ - unsigned char out; - out = ( (B1<<2) & 0xf0 ) | ( (B2>>2) & 0x0f ); - return out; -} - -// dealing with 4 1*8 blocks of FP6 -void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], vector Seg_4bit[], unsigned char* PTR_1, unsigned char* PTR_2, unsigned char* PTR_3, unsigned char* PTR_4) -{ - unsigned char Padded_8_FP8[4][8]; - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); - // - unsigned char Seg1_Byte1_T[4]; - unsigned char Seg1_Byte2_T[4]; - unsigned char Seg2_Byte1_T[4]; - unsigned char Seg2_Byte2_T[4]; - unsigned char Seg2_Byte3_T[4]; - unsigned char Seg2_Byte4_T[4]; - for(int t=0; t<4; t++) - { - Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2], Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); - Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2], Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); - Seg2_Byte1_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0+t*2], Padded_8_FP8[0][1+t*2]); - Seg2_Byte2_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0+t*2], Padded_8_FP8[1][1+t*2]); - Seg2_Byte3_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0+t*2], Padded_8_FP8[2][1+t*2]); - Seg2_Byte4_T[t] = Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0+t*2], Padded_8_FP8[3][1+t*2]); - } - // - for(int t=0; t<4; t++) - { - Seg_2bit[t].push_back(Seg1_Byte1_T[t]); - Seg_2bit[t].push_back(Seg1_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte1_T[t]); - Seg_4bit[t].push_back(Seg2_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte3_T[t]); - Seg_4bit[t].push_back(Seg2_Byte4_T[t]); - } - return; -} - -void BitInterleaving_2bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - //int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for bit-interleaving in FP6-LLM - int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. 
- for(int i=0; i<16; i++) - Frags_2bit[i] = ( input << 2*(order_2bit[i]-1) ) & 0xc0000000; - // - unsigned int output = 0x00000000; - for(int i=0; i<16; i++) - output |= ( Frags_2bit[i] >> (i*2) ); - // - *PTR_UINT = output; -} - -void BitInterleaving_4bit(unsigned char* PTR_4Bytes) -{ - unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - //int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in FP6-LLM - int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM - unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. - for(int i=0; i<8; i++) - Frags_4bit[i] = ( input << 4*(order_4bit[i]-1) ) & 0xf0000000; - // - unsigned int output = 0x00000000; - for(int i=0; i<8; i++) - output |= ( Frags_4bit[i] >> (i*4) ); - // - *PTR_UINT = output; -} - -/* - * Inputs: - * (1) unsigned char Weight_6bit [M*K*6/8] - * Outputs: - * (1) unsigned char Weight_2bit [M*K*2/8] - * (2) unsigned char Weight_4bit [M*K*4/8] - * - * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. - * 8 FP6 = 6 Bytes - * 8 FP4 = 4 Bytes - * 8 FP2 = 2 Bytes - */ -void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K) -{ - assert(M % 64 == 0); - assert(K % 64 == 0); - // - unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); - unsigned char* Weight_2bit = reinterpret_cast(packed_weights); - unsigned char* Weight_4bit = Weight_2bit + M*K*2/8; - // - vector A_Segment_2bit[32]; - vector A_Segment_4bit[32]; - // - size_t BytesPerRow = K*6/8; - // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. - for (size_t i = 0; i < M / 64; i++) // - { - for (size_t j = 0; j < K / 16; j++) - { - for(size_t k=0; k<64/16; k++) - { - size_t row = i*64 + k*16; - size_t col = j*16; - unsigned char* StartPTR_1 = Weight_6bit + row*BytesPerRow + col*6/8; - unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; - unsigned char* StartPTR_3 = StartPTR_1 + 8*6/8; - unsigned char* StartPTR_4 = StartPTR_2 + 8*6/8; - // Dealing with each 16*16 blocks then... - for(int l=0; l<8; l++) Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l*4], &A_Segment_4bit[l*4], StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow); - } - } - } - // Verifying the length of 2_bit segments and 4_bit segments - size_t BytesPerThread_2bit = M*K*2/8/32; - size_t BytesPerThread_4bit = M*K*4/8/32; - for(int i=0; i<32; i++) - { - assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); - assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); - } - // Pass-2: Optimizing coleasced global memory access - for(size_t i=0; i -#include - -namespace torchao { - -/* - * Weight prepacking (Pytorch interface). - * [Input & Output] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
- * [Output] - * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; - */ -at::Tensor weight_matrix_prepacking_cpu(at::Tensor fp6_tensor) -{ - size_t OC = fp6_tensor.size(0); - size_t IC = fp6_tensor.size(1); - TORCH_CHECK(IC % 3 == 0, "Expect packed input dim % 3 == 0, but receive ", IC, " instead."); - IC = IC * 16 / 3; - TORCH_CHECK((OC % 256 == 0) && (IC % 64 == 0), "Expect output dim % 256 == 0 and input dim % 64 == 0, but receive ", OC, " and ", IC, " instead."); - auto packed_tensor = at::empty_like(fp6_tensor); - auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); - return packed_tensor; -} - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::prepack_fp6_weight", &weight_matrix_prepacking_cpu); -} - -} diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 44077dab6..dccd22f3d 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -9,10 +9,3 @@ "AffineQuantizedTensor", "to_aq", ] - -# CPP extensions -try: - from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2 - __all__.extend(["to_float6_e3m2", "from_float6_e3m2"]) -except RuntimeError: - pass diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py deleted file mode 100644 index 0c27838d0..000000000 --- a/torchao/dtypes/float6_e3m2.py +++ /dev/null @@ -1,178 +0,0 @@ -import torch -from torch import Tensor -from torch.utils._triton import has_triton -from torchao.ops import to_float6_e3m2_packed_cpu, to_float6_e3m2_unpacked_cpu, from_float6_e3m2_packed_cpu, from_float6_e3m2_unpacked_cpu - - -# some useful constants -FLOAT6_E3M2_MAX = 28.0 -FLOAT6_E3M2_SMALLEST_SUBNORMAL = 0.0625 - - -if has_triton(): - import triton - from triton import language as tl - - # see _to_float6_e3m2_pt() for explanation - @triton.jit - def _triton_float32_to_float6_e3m2(x: tl.tensor): - x = x.to(tl.float32) - x = x * 2.0 ** (-127 + 3) - bits = x.to(tl.int32, bitcast=True) - - sign = ((bits >> 31) & 0x1) << 5 - exp_and_man = (bits >> 21) & 0x1F - result = sign | exp_and_man - - remainder = bits & 0x1F_FFFF - do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) - result = tl.where(do_round_up, result + 1, result) - return result.to(tl.uint8) - - @triton.jit - def _to_float6_e3m2_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n - - # strided memory read. there will be uncoalesced memory access - val0 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4, mask)) - val1 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 1, mask)) - val2 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 2, mask)) - val3 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 3, mask)) - - # bit packing - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - - # strided memory write. 
there will be uncoalesced memory access - tl.store(out_ptr + offsets * 3, bits0, mask) - tl.store(out_ptr + offsets * 3 + 1, bits1, mask) - tl.store(out_ptr + offsets * 3 + 2, bits2, mask) - - def _to_float6_e3m2_triton(tensor: Tensor) -> Tensor: - out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,) - output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8) - - n = tensor.numel() - grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"] * 4),) - _to_float6_e3m2_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) - - return output - -else: - _to_float6_e3m2_triton = None - - -# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. -# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). -def _to_float6_e3m2_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: - tensor = tensor.float() - - # correct exponent bias. this also handles subnormal numbers correctly - tensor = tensor * 2.0 ** (-127 + 3) - bits = tensor.view(torch.int32) - - sign = ((bits >> 31) & 0x1) << 5 - exp_and_man = (bits >> 21) & 0x1F - result = sign | exp_and_man - - # round to nearest even - remainder = bits & 0x1F_FFFF # truncated mantissa bits - do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) - result = torch.where(do_round_up, result + 1, result) - result = result.to(torch.uint8) - - if no_bit_packing: - return result - - # bit packing - val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) - - -def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: - """Convert input tensor to FP6. This particular FP6 format has 3 exponent bits and 2 mantissa - bits. By default, bit packing is performed: every 4 FP6 values are packed as 3 uint8 values - (4 x 6 bits = 3 x 8 bits). - - Args: - tensor: Input tensor. The last dimension must be divisible by 4 (unless ``no_bit_packing=False``) - no_bit_packing: Whether to not perform bit packing. Setting this to ``True`` can be useful for - observing the bit patterns and debugging. - - Returns: - :class:`torch.Tensor`: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last - dimension of output tensor is 3/4 of that of input tensor. - - Note: - This FP6 format does not represent +/-inf and NaN. Thus, make sure that input tensor does - not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. - All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - - See also :func:`from_float6_e3m2` - """ - if not no_bit_packing: - assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" - - if tensor.is_cpu: - if no_bit_packing: - return to_float6_e3m2_unpacked_cpu(tensor) - - *leading_dims, last_dim = tensor.shape - return to_float6_e3m2_packed_cpu(tensor.view(-1, last_dim)).view(*leading_dims, -1) - - # torch.compile() cannot generate fused bit-packing triton kernel, - # thus we write custom triton kernel for this specific case. - if tensor.is_cuda and not no_bit_packing and _to_float6_e3m2_triton is not None: - return _to_float6_e3m2_triton(tensor) - - else: - return _to_float6_e3m2_pt(tensor, no_bit_packing=no_bit_packing) - - -# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. 
-# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). -def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: - bits = tensor.to(torch.int32) # bit extension - sign = bits >> 5 << 31 - exp_and_man = (bits & 0x1F) << 21 - results = sign | exp_and_man - - results = results.view(torch.float32) - return results * 2.0 ** (127 - 3) # exponent bias correction - - -def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch.dtype = torch.float32) -> Tensor: - """Convert an FP6 tensor (created by :func:`to_float6_e3m2`) to FP32. - - Args: - tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must - be divisible by 3. - no_bit_packing: whether the input does not have bit packing. - dtype: returned dtype. - - Returns: - :class:`torch.Tensor`: FP32 tensor. If ``no_bit_packing=False``, the last dimension of output - tensor is 4/3 of that of input tensor. - """ - assert tensor.dtype == torch.uint8 - if no_bit_packing: - if tensor.is_cpu: - return from_float6_e3m2_unpacked_cpu(tensor, dtype) - - return _pt_float6_e3m2_to_float32(tensor).to(dtype) - - assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" - if tensor.is_cpu: - return from_float6_e3m2_packed_cpu(tensor, dtype) - - bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) - val0 = _pt_float6_e3m2_to_float32(bits0 >> 2).to(dtype) - val1 = _pt_float6_e3m2_to_float32(((bits0 & 0x3) << 4) | (bits1 >> 4)).to(dtype) - val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)).to(dtype) - val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F).to(dtype) - return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) diff --git a/torchao/ops.py b/torchao/ops.py index 7fce2de22..2a603716e 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -12,48 +12,6 @@ def decorator(func): return decorator -def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: - """ - Pack FP6 tensor in a layout for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. - - Arguments - fp6_weight: tightly-packed fp6_weight, inside a `torch.int32` container - - Returns - packed FP6 tensor for use with FP6-LLM, inside a `torch.int32` container - """ - return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) - - -# Defines the meta kernel / fake kernel / abstract impl -@register_custom_op("torchao::prepack_fp6_weight") -def _(fp6_weight): - torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") - return torch.empty_like(fp6_weight) - - -def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: - """ - Pack FP16 tensor to FP6 tensor. qtorch is required to use this function. 
- """ - try: - from qtorch.quant import float_quantize - except ImportError as e: - raise RuntimeError("Please install qtorch to use this function") from e - - fp16_tensor = float_quantize(fp16_tensor.float(), 3, 2, rounding="nearest").half() - return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) - - -@register_custom_op("torchao::fp16_to_fp6_original") -def _(fp16_tensor): - torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") - torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") - M, K = fp16_tensor.shape - torch._check(K % 4 == 0, lambda: f"second dimension must be a multiple of 4, got {K}") - return torch.empty((M, K * 6 // 8), dtype=torch.uint8, device=fp16_tensor.device) - - def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor: """ FP6-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. @@ -85,19 +43,3 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) - - -def to_float6_e3m2_unpacked_cpu(tensor: Tensor) -> Tensor: - return torch.ops.torchao.to_float6_e3m2_unpacked_cpu.default(tensor) - - -def to_float6_e3m2_packed_cpu(tensor: Tensor) -> Tensor: - return torch.ops.torchao.to_float6_e3m2_packed_cpu.default(tensor) - - -def from_float6_e3m2_unpacked_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: - return torch.ops.torchao.from_float6_e3m2_unpacked_cpu.default(tensor, dtype) - - -def from_float6_e3m2_packed_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: - return torch.ops.torchao.from_float6_e3m2_packed_cpu.default(tensor, dtype) From 571910b8d58d3e05220af5c13e52739bb834beff Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 1 Jun 2024 23:18:54 +0800 Subject: [PATCH 05/19] comment out test --- test/dtypes/test_float6_e3m2.py | 268 ++++++++++++++++---------------- 1 file changed, 134 insertions(+), 134 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index 304a78c56..6f37f0a1a 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -1,134 +1,134 @@ -import torch -from torch.testing._internal.common_utils import ( - TestCase, - instantiate_parametrized_tests, - parametrize, - run_tests, -) - -try: - import torchao.ops -except RuntimeError: - pytest.skip("torchao.ops not available") - - -from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 - - -_DTYPES = [torch.float32, torch.float16, torch.bfloat16] -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - - -class TestFloat6E3M2(TestCase): - - @parametrize("device", _DEVICES) - @parametrize("dtype", _DTYPES) - @parametrize( - "input_output", - [ - (0.0, 0b000000), # exact values - (1.0, 0b001100), # normal numbers - (1.25, 0b001101), - (28.0, 0b011111), # max - (0.1875, 0b000011), # subnormal number - (0.0625, 0b000001), # min - (29.0, 0b011111), # normal round down - (26.0, 0b011110), # normal round to nearest even - (0.1251, 0b000010), # subnormal round down - (0.0314, 0b000001), # subnormal round up - (0.03, 0b000000), # underflow - ], - ) - def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): - input, output = input_output - input = torch.tensor(input, device=device, dtype=dtype) - assert to_float6_e3m2(input, no_bit_packing=True).item() == output - - @parametrize("device", 
_DEVICES) - @parametrize("dtype", _DTYPES) - def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): - x = torch.randn(128, 128, device=device, dtype=dtype) - results_unpacked = to_float6_e3m2(x, no_bit_packing=True) - results_packed = to_float6_e3m2(x) - - val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - - expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) - assert (results_packed == expected_packed).all() - - @parametrize("device", _DEVICES) - @parametrize("shape", [(), (0,), (10,), (20, 20)]) - def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): - x = torch.randn(shape, device=device) - result = to_float6_e3m2(x, no_bit_packing=True) - assert result.shape == shape - - @parametrize("device", _DEVICES) - @parametrize("shape", [(4,), (20, 20)]) - def test_to_float6_e3m2_bit_packing_shape(self, device, shape): - x = torch.randn(shape, device=device) - result = to_float6_e3m2(x) - assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) - - @parametrize("device", _DEVICES) - @parametrize("dtype", _DTYPES) - @parametrize("no_bit_packing", [False, True]) - def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): - x = torch.randn(20, 20, device=device, dtype=dtype) - expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) - - to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) - actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) - torch.testing.assert_close(actual, expected) - - @parametrize("device", _DEVICES) - @parametrize( - "input_output", - [ - (0b000000, 0.0), - (0b001100, 1.0), - (0b011111, 28.0), # max - (0b000001, 0.0625), # min - (0b001110, 1.5), - (0b000011, 0.1875), # subnormal - ], - ) - def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): - input, output = input_output - input = torch.tensor(input, device=device, dtype=torch.uint8) - assert from_float6_e3m2(input, no_bit_packing=True).item() == output - - @parametrize("device", _DEVICES) - def test_from_float6_e3m2_bit_packing_correctness(self, device): - x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) - actual = from_float6_e3m2(x) - - bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) - x_unpacked0 = bits0 >> 2 - x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) - x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) - x_unpacked3 = bits2 & 0x3F - - x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) - expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) - torch.testing.assert_close(actual, expected) - - @parametrize("device", _DEVICES) - @parametrize("no_bit_packing", [False, True]) - def test_from_float6_e3m2_compile(self, device, no_bit_packing): - x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) - expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) - - from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) - actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) - torch.testing.assert_close(actual, expected) - - -instantiate_parametrized_tests(TestFloat6E3M2) - - -if __name__ == "__main__": - run_tests() +# import torch +# from torch.testing._internal.common_utils import ( +# TestCase, +# instantiate_parametrized_tests, +# parametrize, +# run_tests, +# ) + +# try: +# import torchao.ops +# except 
RuntimeError: +# pytest.skip("torchao.ops not available") + + +# from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 + + +# _DTYPES = [torch.float32, torch.float16, torch.bfloat16] +# _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +# class TestFloat6E3M2(TestCase): + +# @parametrize("device", _DEVICES) +# @parametrize("dtype", _DTYPES) +# @parametrize( +# "input_output", +# [ +# (0.0, 0b000000), # exact values +# (1.0, 0b001100), # normal numbers +# (1.25, 0b001101), +# (28.0, 0b011111), # max +# (0.1875, 0b000011), # subnormal number +# (0.0625, 0b000001), # min +# (29.0, 0b011111), # normal round down +# (26.0, 0b011110), # normal round to nearest even +# (0.1251, 0b000010), # subnormal round down +# (0.0314, 0b000001), # subnormal round up +# (0.03, 0b000000), # underflow +# ], +# ) +# def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): +# input, output = input_output +# input = torch.tensor(input, device=device, dtype=dtype) +# assert to_float6_e3m2(input, no_bit_packing=True).item() == output + +# @parametrize("device", _DEVICES) +# @parametrize("dtype", _DTYPES) +# def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): +# x = torch.randn(128, 128, device=device, dtype=dtype) +# results_unpacked = to_float6_e3m2(x, no_bit_packing=True) +# results_packed = to_float6_e3m2(x) + +# val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) +# bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 +# bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 +# bits2 = (val2 << 6) | (val3); # 2233 3333 + +# expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) +# assert (results_packed == expected_packed).all() + +# @parametrize("device", _DEVICES) +# @parametrize("shape", [(), (0,), (10,), (20, 20)]) +# def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): +# x = torch.randn(shape, device=device) +# result = to_float6_e3m2(x, no_bit_packing=True) +# assert result.shape == shape + +# @parametrize("device", _DEVICES) +# @parametrize("shape", [(4,), (20, 20)]) +# def test_to_float6_e3m2_bit_packing_shape(self, device, shape): +# x = torch.randn(shape, device=device) +# result = to_float6_e3m2(x) +# assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) + +# @parametrize("device", _DEVICES) +# @parametrize("dtype", _DTYPES) +# @parametrize("no_bit_packing", [False, True]) +# def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): +# x = torch.randn(20, 20, device=device, dtype=dtype) +# expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) + +# to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) +# actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) +# torch.testing.assert_close(actual, expected) + +# @parametrize("device", _DEVICES) +# @parametrize( +# "input_output", +# [ +# (0b000000, 0.0), +# (0b001100, 1.0), +# (0b011111, 28.0), # max +# (0b000001, 0.0625), # min +# (0b001110, 1.5), +# (0b000011, 0.1875), # subnormal +# ], +# ) +# def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): +# input, output = input_output +# input = torch.tensor(input, device=device, dtype=torch.uint8) +# assert from_float6_e3m2(input, no_bit_packing=True).item() == output + +# @parametrize("device", _DEVICES) +# def test_from_float6_e3m2_bit_packing_correctness(self, device): +# x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) +# actual = from_float6_e3m2(x) + +# bits0, bits1, bits2 = 
x.unflatten(-1, (-1, 3)).unbind(-1) +# x_unpacked0 = bits0 >> 2 +# x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) +# x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) +# x_unpacked3 = bits2 & 0x3F + +# x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) +# expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) +# torch.testing.assert_close(actual, expected) + +# @parametrize("device", _DEVICES) +# @parametrize("no_bit_packing", [False, True]) +# def test_from_float6_e3m2_compile(self, device, no_bit_packing): +# x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) +# expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) + +# from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) +# actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) +# torch.testing.assert_close(actual, expected) + + +# instantiate_parametrized_tests(TestFloat6E3M2) + + +# if __name__ == "__main__": +# run_tests() From 4e2964a2712190d5c2f49bf3133b38a617950754 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Jun 2024 10:15:03 +0800 Subject: [PATCH 06/19] remove --- docs/source/api_ref_dtypes.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 36c3c9b4e..4cb797beb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -12,8 +12,6 @@ torchao.dtypes to_nf4 UInt4Tensor - to_float6_e3m2 - from_float6_e3m2 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring From adefee8f9987770b27f2c74d1bcb3771ffd59e7f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Jun 2024 10:26:17 +0800 Subject: [PATCH 07/19] add rounding test for f6_e3m2 --- test/dtypes/test_float6_e3m2.py | 134 ------------------ test/prototype/mx_formats/test_custom_cast.py | 16 ++- 2 files changed, 15 insertions(+), 135 deletions(-) delete mode 100644 test/dtypes/test_float6_e3m2.py diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py deleted file mode 100644 index 6f37f0a1a..000000000 --- a/test/dtypes/test_float6_e3m2.py +++ /dev/null @@ -1,134 +0,0 @@ -# import torch -# from torch.testing._internal.common_utils import ( -# TestCase, -# instantiate_parametrized_tests, -# parametrize, -# run_tests, -# ) - -# try: -# import torchao.ops -# except RuntimeError: -# pytest.skip("torchao.ops not available") - - -# from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 - - -# _DTYPES = [torch.float32, torch.float16, torch.bfloat16] -# _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - - -# class TestFloat6E3M2(TestCase): - -# @parametrize("device", _DEVICES) -# @parametrize("dtype", _DTYPES) -# @parametrize( -# "input_output", -# [ -# (0.0, 0b000000), # exact values -# (1.0, 0b001100), # normal numbers -# (1.25, 0b001101), -# (28.0, 0b011111), # max -# (0.1875, 0b000011), # subnormal number -# (0.0625, 0b000001), # min -# (29.0, 0b011111), # normal round down -# (26.0, 0b011110), # normal round to nearest even -# (0.1251, 0b000010), # subnormal round down -# (0.0314, 0b000001), # subnormal round up -# (0.03, 0b000000), # underflow -# ], -# ) -# def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): -# input, output = input_output -# input = torch.tensor(input, device=device, dtype=dtype) -# assert to_float6_e3m2(input, no_bit_packing=True).item() == output - -# @parametrize("device", _DEVICES) -# @parametrize("dtype", _DTYPES) -# def 
test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): -# x = torch.randn(128, 128, device=device, dtype=dtype) -# results_unpacked = to_float6_e3m2(x, no_bit_packing=True) -# results_packed = to_float6_e3m2(x) - -# val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) -# bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 -# bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 -# bits2 = (val2 << 6) | (val3); # 2233 3333 - -# expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) -# assert (results_packed == expected_packed).all() - -# @parametrize("device", _DEVICES) -# @parametrize("shape", [(), (0,), (10,), (20, 20)]) -# def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): -# x = torch.randn(shape, device=device) -# result = to_float6_e3m2(x, no_bit_packing=True) -# assert result.shape == shape - -# @parametrize("device", _DEVICES) -# @parametrize("shape", [(4,), (20, 20)]) -# def test_to_float6_e3m2_bit_packing_shape(self, device, shape): -# x = torch.randn(shape, device=device) -# result = to_float6_e3m2(x) -# assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) - -# @parametrize("device", _DEVICES) -# @parametrize("dtype", _DTYPES) -# @parametrize("no_bit_packing", [False, True]) -# def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): -# x = torch.randn(20, 20, device=device, dtype=dtype) -# expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) - -# to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) -# actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) -# torch.testing.assert_close(actual, expected) - -# @parametrize("device", _DEVICES) -# @parametrize( -# "input_output", -# [ -# (0b000000, 0.0), -# (0b001100, 1.0), -# (0b011111, 28.0), # max -# (0b000001, 0.0625), # min -# (0b001110, 1.5), -# (0b000011, 0.1875), # subnormal -# ], -# ) -# def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): -# input, output = input_output -# input = torch.tensor(input, device=device, dtype=torch.uint8) -# assert from_float6_e3m2(input, no_bit_packing=True).item() == output - -# @parametrize("device", _DEVICES) -# def test_from_float6_e3m2_bit_packing_correctness(self, device): -# x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) -# actual = from_float6_e3m2(x) - -# bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) -# x_unpacked0 = bits0 >> 2 -# x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) -# x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) -# x_unpacked3 = bits2 & 0x3F - -# x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) -# expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) -# torch.testing.assert_close(actual, expected) - -# @parametrize("device", _DEVICES) -# @parametrize("no_bit_packing", [False, True]) -# def test_from_float6_e3m2_compile(self, device, no_bit_packing): -# x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) -# expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) - -# from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) -# actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) -# torch.testing.assert_close(actual, expected) - - -# instantiate_parametrized_tests(TestFloat6E3M2) - - -# if __name__ == "__main__": -# run_tests() diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 85d958633..1138a8eb5 100644 --- 
a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -387,4 +387,18 @@ def test_fp6_values(dtype_name): raise AssertionError("unsupported") torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0) -# TODO: move test/dtypes/test_float6_e3m2.py here + +@pytest.mark.parametrize("device", ["cpu"] + (["cuda" if torch.cuda.is_available() else []])) +@pytest.mark.parametrize( + "f32_val,f6_e3m2_enc", + [ + (29.0, 0b011111), # normal round down + (26.0, 0b011110), # normal round to nearest even + (0.1251, 0b000010), # subnormal round down + (0.0314, 0b000001), # subnormal round up + (0.03, 0b000000), # underflow + ] +) +def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device): + f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device)) + assert f6_e3m2_unpacked.item() == f6_e3m2_enc From f8268f03dc5766d770efa8b808d6a66f37f959ff Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Jun 2024 10:45:01 +0800 Subject: [PATCH 08/19] update tests --- test/quantization/test_fp6_llm.py | 27 +++++++---- test/test_ops.py | 76 ++++--------------------------- torchao/quantization/fp6_llm.py | 17 ++++--- 3 files changed, 37 insertions(+), 83 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 635f78765..7ab3c63b2 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -7,9 +7,14 @@ parametrize, run_tests, ) -from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 -from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm -from torchao.ops import prepack_fp6_weight +from torchao.quantization.fp6_llm import ( + to_tc_float6_e3m2, + from_tc_float6_e3m2, + _to_tc_float6_e3m2_ref, + Fp6LlmLinear, + convert_fp6_llm, +) +from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -20,9 +25,9 @@ class TestFp6LlmLinear(TestCase): def test_to_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) - expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8) + expected = _to_tc_float6_e3m2_ref(x) actual = to_tc_float6_e3m2(x) - torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1)) + torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) def test_to_tc_float6_e3m2_compile(self, device): @@ -35,18 +40,20 @@ def test_to_tc_float6_e3m2_compile(self, device): @parametrize("device", _DEVICES) def test_from_tc_float6_e3m2_correctness(self, device): x = torch.randn(256, 64, device=device) - x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6 - actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape) + # quantize and dequantize so that the values are exactly representable in FP6 + x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x)) + + actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x)) torch.testing.assert_close(actual, x) @parametrize("device", _DEVICES) def test_from_tc_float6_e3m2_compile(self, device): M, N = 256, 64 - x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device) + x = torch.randint(256, size=(M, N * 3 // 16), dtype=torch.int32, device=device) - expected = from_tc_float6_e3m2(x, M, N) - actual = torch.compile(from_tc_float6_e3m2)(x, M, N) + expected = from_tc_float6_e3m2(x) + actual = 
torch.compile(from_tc_float6_e3m2)(x) torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/test/test_ops.py b/test/test_ops.py index b20e02938..376db266c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,7 @@ from torch.testing._internal.common_utils import TestCase, IS_FBCODE from torch.testing._internal.optests import opcheck import torchao -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.quantization.fp6_llm import from_tc_float6_e3m2 import unittest from parameterized import parameterized import pytest @@ -18,58 +18,12 @@ @pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning") @unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") class TestOps(TestCase): - def _create_tensors_with_iou(self, N, iou_thresh): - # force last box to have a pre-defined iou with the first box - # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], - # then, in order to satisfy ops.iou(b0, b1) == iou_thresh, - # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh - # Adjust the threshold upward a bit with the intent of creating - # at least one box that exceeds (barely) the threshold and so - # should be suppressed. - boxes = torch.rand(N, 4) * 100 - boxes[:, 2:] += boxes[:, :2] - boxes[-1, :] = boxes[0, :] - x0, y0, x1, y1 = boxes[-1].tolist() - iou_thresh += 1e-5 - boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh - scores = torch.rand(N) - return boxes, scores - - def _create_fp6_inputs(self, BS: int, OC: int, IC: int): + def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device): # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) fp16_scale = torch.rand(OC).half() + 0.5 fp16_activation = torch.rand(BS, IC).half() + 0.5 - return fp6_weight, fp16_scale, fp16_activation - - def test_prepack_fp6_weight(self): - OC = 256 - IC = 256 - fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC) - - # smoke test - torchao.ops.prepack_fp6_weight(fp6_weight) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp16_to_fp6_original(self): - OC = 256 - IC = 256 - fp16_weight = torch.randn((OC, IC), dtype=torch.float16) - - # the original FP16->FP6 kernel checks for overflow/underflow - fp16_weight.clip_(-28.0, 28.0) - fp16_weight[fp16_weight.abs() < 0.0625] = 0.0 - - # smoke test - torchao.ops.fp16_to_fp6_original(fp16_weight) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils) + return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp16act_fp6weight_linear(self): @@ -77,35 +31,25 @@ def test_fp16act_fp6weight_linear(self): OC = 256 IC = 256 splitK = 1 - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) - - fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) - act_cuda = fp16_activation.cuda() - weight_cuda = fp6_weight_packed.cuda() - scale_cuda = fp16_scale.cuda() + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") # smoke test - torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) + opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils) # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC) - - fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) - act_cuda = fp16_activation.cuda() - weight_cuda = fp6_weight_packed.cuda() - scale_cuda = fp16_scale.cuda() + fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") - results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) + results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK) - fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] - results_fp16 = act_cuda @ fp16_weight.cuda().T + fp16_weight = from_tc_float6_e3m2(fp6_weight, dtype=torch.float16) * fp16_scale[:, None] + results_fp16 = 
fp16_activation @ fp16_weight.T error = (results_fp6 - results_fp16).abs() relative_error = error / results_fp16.abs() diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 4eb20f86f..d7f66bc33 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -26,7 +26,7 @@ def _unpack_4bit(x: Tensor) -> Tensor: # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing # https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: +def _to_tc_float6_e3m2_ref(tensor: Tensor) -> Tensor: assert tensor.ndim == 2 M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) @@ -66,7 +66,7 @@ def _to_tc_float6_e3m2_original(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] tensor_4bit = _pack_4bit(tensor_4bit).view(-1) - return torch.cat([tensor_2bit, tensor_4bit], dim=0) + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1).view(torch.int) # more optimized version of _to_tc_float6_e3m2_original() by merging ops @@ -76,7 +76,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: M, N = tensor.shape assert (M % 64 == 0) and (N % 64 == 0) - tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor) + tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) tensor_fp6 = tensor_fp6.flip(3) @@ -88,7 +88,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) tensor_4bit = _pack_4bit(tensor_4bit.flatten()) - return torch.cat([tensor_2bit, tensor_4bit], dim=0) + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1).view(torch.int) def to_scaled_tc_float6_e3m2(tensor: Tensor) -> tuple[Tensor, Tensor]: @@ -98,11 +98,14 @@ def to_scaled_tc_float6_e3m2(tensor: Tensor) -> tuple[Tensor, Tensor]: return tc_fp6_tensor, scale.reciprocal().half() -def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = torch.float32) -> Tensor: - assert tensor.ndim == 1 +def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.int32 + M = tensor.shape[0] + N = tensor.shape[1] // 3 * 16 assert (M % 64 == 0) and (N % 64 == 0) size_2bit = M * N // 4 size_4bit = M * N // 2 + tensor = tensor.view(-1).view(torch.uint8) assert tensor.numel() == size_2bit + size_4bit tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) @@ -293,7 +296,7 @@ def from_float(cls, linear: nn.Linear): # we have to override this internal method to be able to convert weights to FP6 on the fly. 
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if state_dict[f"{prefix}weight"].shape == (self.out_features, self.in_features): - fp6_weight, scale = to_scaled_tc_float6_e3m2(state_dict[f"{prefix}weight"]) + fp6_weight, scale = to_scaled_tc_float6_e3m2(state_dict.pop(f"{prefix}weight")) state_dict[f"{prefix}weight"] = fp6_weight state_dict[f"{prefix}scales"] = scale From ebbff674cef2cf4b7d992f10762c8cec64e01b6d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Jun 2024 10:46:25 +0800 Subject: [PATCH 09/19] remove openmp flag --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 014a36832..71d674714 100644 --- a/setup.py +++ b/setup.py @@ -49,12 +49,11 @@ def get_extensions(): use_cuda = torch.cuda.is_available() and CUDA_HOME is not None extension = CUDAExtension if use_cuda else CppExtension - extra_link_args = ["-fopenmp"] + extra_link_args = [] extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", - "-fopenmp", ], "nvcc": [ "-O3" if not debug_mode else "-O0", From 25e4be79044b0a62bec6443e451676b9c32ff474 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Jun 2024 13:12:29 +0800 Subject: [PATCH 10/19] update benchmark script --- benchmarks/benchmark_fp6.py | 60 ++++++++----------------------------- 1 file changed, 13 insertions(+), 47 deletions(-) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index abe21d2f7..000026b5e 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -1,39 +1,21 @@ import torch -import torchao +from torch import nn +from torchao.quantization.fp6_llm import Fp6LlmLinear from torch.utils.benchmark import Timer import pandas as pd from tqdm import tqdm -def benchmark(m, k, n, splitK): - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
- fp6_weight = torch.randint(4294967295, (n, k // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(n).half() + 0.5 - fp16_activation = torch.rand(m, k).half() + 0.5 +def benchmark(m: int, k: int, n: int): + fp16_act = torch.randn(m, k, device="cuda", dtype=torch.half) + fp16_linear = nn.Linear(k, n, bias=False, device="cuda", dtype=torch.half) + fp6_linear = Fp6LlmLinear.from_float(fp16_linear) - fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight) - act_cuda = fp16_activation.cuda() - weight_cuda = fp6_weight_packed.cuda() - scale_cuda = fp16_scale.cuda() + fp6_output = fp6_linear(fp16_act) + fp16_output = fp16_linear(fp16_act) - # need to do this since Timer cannot see torchao - def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK): - return torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) - - fp6_output = fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK) - - fp6_measurement = Timer( - stmt="fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK)", - globals=locals(), - ).blocked_autorange() - - fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() - fp16_output = act_cuda @ fp16_weight.T - - fp16_measurement = Timer( - stmt="act_cuda @ fp16_weight.T", - globals=locals(), - ).blocked_autorange() + fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange() + fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange() # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness @@ -57,25 +39,9 @@ def fp6_linear(act_cuda, weight_cuda, scale_cuda, splitK): results = [] - # splitK can be tuned based on m, k, n - for m, splitK_vals in tqdm([ - (1, (5, 6, 7, 6)), - (2, (5, 6, 7, 6)), - (4, (5, 6, 7, 6)), - (8, (5, 6, 7, 6)), - # (16, (5, 6, 7, 6)), - # (64, (5, 6, 7, 6)), - # (128, (5, 3, 3, 3)), - # (256, (4, 3, 2, 3)), - # (512, (2, 5, 2, 4)), - (1024, (1, 2, 1, 2)), - (2048, (1, 1, 1, 1)), - (4096, (1, 1, 1, 1)), - # (8192, (1, 1, 1, 1)), - # (16384, (1, 1, 1, 1)), - ]): - for n, k, splitK in zip(n_vals, k_vals, splitK_vals): - results.append(benchmark(m, n, k, splitK)) + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, n, k)) df = pd.DataFrame(results) df.to_csv("fp6_benchmark_results.csv", index=False) From 21dfc6088593308bbdea5411389f882c73cbc794 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 3 Jun 2024 22:14:27 +0800 Subject: [PATCH 11/19] test negative number --- test/prototype/mx_formats/test_custom_cast.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 1138a8eb5..28f33976a 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -402,3 +402,6 @@ def test_fp6_values(dtype_name): def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device): f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device)) assert f6_e3m2_unpacked.item() == f6_e3m2_enc + + f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device)) + assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000) From 64e24f74e1557169d0f801ff06bade721ff86749 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 3 Jun 2024 22:17:09 +0800 Subject: [PATCH 12/19] remove qtorch dep --- dev-requirements.txt | 3 --- 1 file 
changed, 3 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 4d6185874..68b17dc88 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,6 +14,3 @@ tabulate # QOL for printing tables to stdout # Custom CUDA Extensions ninja - -# for FP6-LLM (can be removed once we remove fp16_to_fp6_original()) -qtorch From 6d6f5dd60155c8c33030d8fbaa290f2086daf4ce Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 3 Jun 2024 22:53:10 +0800 Subject: [PATCH 13/19] fix type casting --- test/quantization/test_fp6_llm.py | 2 +- torchao/quantization/fp6_llm.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/quantization/test_fp6_llm.py b/test/quantization/test_fp6_llm.py index 7ab3c63b2..906154d33 100644 --- a/test/quantization/test_fp6_llm.py +++ b/test/quantization/test_fp6_llm.py @@ -50,7 +50,7 @@ def test_from_tc_float6_e3m2_correctness(self, device): @parametrize("device", _DEVICES) def test_from_tc_float6_e3m2_compile(self, device): M, N = 256, 64 - x = torch.randint(256, size=(M, N * 3 // 16), dtype=torch.int32, device=device) + x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device) expected = from_tc_float6_e3m2(x) actual = torch.compile(from_tc_float6_e3m2)(x) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index d7f66bc33..d478e27be 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -66,7 +66,7 @@ def _to_tc_float6_e3m2_ref(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] tensor_4bit = _pack_4bit(tensor_4bit).view(-1) - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1).view(torch.int) + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) # more optimized version of _to_tc_float6_e3m2_original() by merging ops @@ -88,20 +88,19 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) tensor_4bit = _pack_4bit(tensor_4bit.flatten()) - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1).view(torch.int) + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) def to_scaled_tc_float6_e3m2(tensor: Tensor) -> tuple[Tensor, Tensor]: scale = F6_E3M2_MAX / tensor.abs().amax(1).clamp(min=1e-12) tc_fp6_tensor = to_tc_float6_e3m2(tensor * scale.view(-1, 1)) - tc_fp6_tensor = tc_fp6_tensor.view(tensor.shape[0], -1).view(torch.int32) return tc_fp6_tensor, scale.reciprocal().half() def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.int32 + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 M = tensor.shape[0] - N = tensor.shape[1] // 3 * 16 + N = tensor.shape[1] // 3 * 4 assert (M % 64 == 0) and (N % 64 == 0) size_2bit = M * N // 4 size_4bit = M * N // 2 @@ -266,11 +265,11 @@ class Fp6LlmLinear(nn.Module): def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None: super().__init__() - self.register_buffer("weight", weight) + self.register_buffer("weight", weight.view(torch.int32)) self.register_buffer("scales", scales) self.register_buffer("bias", bias) self.out_features = weight.shape[0] - self.in_features = weight.shape[1] * 16 // 3 + self.in_features = weight.shape[1] // 3 * 4 def forward(self, x: Tensor) -> Tensor: splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) From 474ebc205481174ce4bfa7cd2aea7d86218088c5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 3 Jun 2024 
23:14:02 +0800 Subject: [PATCH 14/19] add view --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 376db266c..016ac24fe 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -48,7 +48,7 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK) - fp16_weight = from_tc_float6_e3m2(fp6_weight, dtype=torch.float16) * fp16_scale[:, None] + fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] results_fp16 = fp16_activation @ fp16_weight.T error = (results_fp6 - results_fp16).abs() From 509217c408e7561c621e821eeecc5c8e9309cf9d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 9 Jun 2024 23:30:45 +0800 Subject: [PATCH 15/19] fix strange pytest behavior --- test/prototype/mx_formats/test_custom_cast.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 0b0bff57d..d148f595c 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -388,7 +388,13 @@ def test_fp6_values(dtype_name): torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0) -@pytest.mark.parametrize("device", ["cpu"] + (["cuda" if torch.cuda.is_available() else []])) +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")), + ] +) @pytest.mark.parametrize( "f32_val,f6_e3m2_enc", [ From 11dcba3292c6a8654096ef8ddc4d68ac645a5787 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 9 Jun 2024 23:42:01 +0800 Subject: [PATCH 16/19] only skip tests requiring PyTorch 2.4 --- test/prototype/mx_formats/test_custom_cast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index d148f595c..393d5b546 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -46,8 +46,6 @@ from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.utils import TORCH_VERSION_AFTER_2_4 -if not TORCH_VERSION_AFTER_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) torch.manual_seed(0) @@ -322,6 +320,7 @@ def test_fp4_pack_unpack(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4") def test_fp4_triton_unscaled_cast(): packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda") f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals)) @@ -331,6 +330,7 @@ def test_fp4_triton_unscaled_cast(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4") def test_fp4_triton_scaled_cast(): size = (256,) orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100 From 6f8e7e93677138682e2e322fd78e7f1d0289ba7d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 9 Jun 2024 23:47:01 +0800 Subject: [PATCH 17/19] remove weight loading magic --- torchao/quantization/fp6_llm.py | 11 
----------- 1 file changed, 11 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index d478e27be..5d57f1636 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -291,17 +291,6 @@ def from_float(cls, linear: nn.Linear): bias = linear.bias.detach().half() if linear.bias is not None else None return cls(fp6_weight, scale, bias) - # without load_state_dict_pre_hook() https://github.com/pytorch/pytorch/issues/75287 - # we have to override this internal method to be able to convert weights to FP6 on the fly. - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if state_dict[f"{prefix}weight"].shape == (self.out_features, self.in_features): - fp6_weight, scale = to_scaled_tc_float6_e3m2(state_dict.pop(f"{prefix}weight")) - - state_dict[f"{prefix}weight"] = fp6_weight - state_dict[f"{prefix}scales"] = scale - - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def extra_repr(self) -> str: return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' From f454f4dce0568b33c0d44aa4a9dcdb29524339e3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 10 Jun 2024 00:35:55 +0000 Subject: [PATCH 18/19] fix typing tuple --- torchao/quantization/fp6_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 5d57f1636..95c51423b 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, Tuple import torch from torch import nn, Tensor @@ -91,7 +91,7 @@ def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) -def to_scaled_tc_float6_e3m2(tensor: Tensor) -> tuple[Tensor, Tensor]: +def to_scaled_tc_float6_e3m2(tensor: Tensor) -> Tuple[Tensor, Tensor]: scale = F6_E3M2_MAX / tensor.abs().amax(1).clamp(min=1e-12) tc_fp6_tensor = to_tc_float6_e3m2(tensor * scale.view(-1, 1)) return tc_fp6_tensor, scale.reciprocal().half() From fa38572f8ceeed8f33fb2c92553eb51d2cb144b0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 10 Jun 2024 00:48:26 +0000 Subject: [PATCH 19/19] fix list typing --- torchao/quantization/fp6_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py index 95c51423b..6fa84644f 100644 --- a/torchao/quantization/fp6_llm.py +++ b/torchao/quantization/fp6_llm.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, List import torch from torch import nn, Tensor @@ -295,7 +295,7 @@ def extra_repr(self) -> str: return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' -def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[list[str]] = None, cur_fqn: str = "") -> None: +def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, cur_fqn: str = "") -> None: for name, child in model.named_children(): new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"
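
Taken together, the series leaves the FP6-LLM flow as an explicit quantize-then-wrap path. The sketch below is a minimal illustration of that end state, using only names that appear in the patches above (to_scaled_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear from torchao.quantization.fp6_llm); the 256x256 weight, the batch size, and the bias-free layer are illustrative choices, not values taken from the series.

import torch
from torch import nn
from torchao.quantization.fp6_llm import (
    Fp6LlmLinear,
    from_tc_float6_e3m2,
    to_scaled_tc_float6_e3m2,
)

# Quantize a weight to the tensor-core FP6 layout plus per-output-channel FP16
# scales, then dequantize for a reference comparison. Both dimensions must be
# multiples of 64, matching the asserts in to_tc_float6_e3m2/from_tc_float6_e3m2.
weight = torch.randn(256, 256)
fp6_weight, scales = to_scaled_tc_float6_e3m2(weight)   # uint8 (256, 192), float16 (256,)
weight_dq = from_tc_float6_e3m2(fp6_weight, dtype=torch.float16) * scales[:, None]

# Run the fused FP16-activation / FP6-weight kernel through the module wrapper,
# mirroring benchmarks/benchmark_fp6.py. Requires CUDA.
if torch.cuda.is_available():
    fp16_linear = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.half)
    fp6_linear = Fp6LlmLinear.from_float(fp16_linear)
    fp16_act = torch.randn(8, 256, device="cuda", dtype=torch.half)
    fp6_out = fp6_linear(fp16_act)

With the _load_from_state_dict override dropped in PATCH 17, conversion is expected to happen explicitly through Fp6LlmLinear.from_float (or convert_fp6_llm for a whole model) rather than on the fly while loading an FP16 checkpoint.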