Move Uintx out of prototype for future extension
Summary:
Thanks @vayuda for adding the initial version of the Uintx tensor subclass.
We can now integrate it with the `torch.uint1` to `torch.uint7` dtypes through some helpers,
which unblocks the main benefit of bitpacking (model size savings) for users first;
we can then gradually optimize the performance.
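
As a rough usage sketch of the new entry point (illustration only, assuming the signature
exercised in the updated test: bit width as the first positional argument, `group_size` as a
keyword, float16 weights, and a CUDA machine on a recent enough nightly):

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

# hypothetical toy model; uintx_weight_only replaces the linear weight with 4-bit packed values
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.float16).to("cuda")
quantize_(model, uintx_weight_only(4, group_size=64))

out = model(torch.randn(1, 256, dtype=torch.float16, device="cuda"))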

ExecuTorch is also planning to integrate its low-bit kernels with us, so a more native experience
with these lower-bit dtypes will be required / useful there as well.

Test Plan:
python test/dtypes/test_uintx.py
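
The test also exercises the lower-level packing path directly; here is a minimal round-trip sketch
(illustration only, assuming the input values are already quantized into [0, 2**bit_width - 1]
and stored as torch.uint8):

import torch
from torchao.dtypes.uintx.Uintx import to_uintx

bit_width = 3
x = torch.randint(0, 2 ** bit_width, (4, 16), dtype=torch.uint8)

packed = to_uintx(x, bit_width, -1)  # pack along the last dim into 2-bit and 1-bit shards
unpacked = packed.get_plain()        # unpack back to plain uint8 values
assert torch.equal(unpacked, x)      # should round-trip losslessly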

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Aug 8, 2024
1 parent 1cfe69e commit 40c2dc4
Showing 6 changed files with 124 additions and 128 deletions.
60 changes: 29 additions & 31 deletions test/prototype/test_uintx.py → test/dtypes/test_uintx.py
@@ -4,28 +4,26 @@

import torch

from torchao.prototype.uintx import uintx_affine_weight_only, to_uintx
from torchao.quantization.quant_api import quantize_
from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)

bit_sizes = (1,2,3,4,5,6,7)
group_sizes = [32,64,128]
bit_widths = (1, 2, 3, 4, 5, 6, 7)
group_sizes = [32, 64, 128]
devices = ["cpu", "cuda"]
@pytest.fixture(autouse=True)
def run_before_and_after_tests():
yield
torch._dynamo.reset() # reset cache between tests



class Linear16(torch.nn.Module):
def __init__(self, scale, device):
super().__init__()
@@ -37,52 +35,52 @@ def __init__(self, scale, device):

def forward(self, x):
return self.net(x)
@pytest.mark.parametrize("bit_size", bit_sizes)

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
def test_uintx_affine_weight_only_model_quant(bit_size, group_size, device):
def test_uintx_weight_only_model_quant(bit_width, group_size, device):
scale = 512
fp16 = Linear16(scale, device)
quantize_(fp16, uintx_affine_weight_only(bit_size, group_size=group_size))
quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
uintx = torch.compile(fp16, fullgraph=True)
test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
output = uintx.forward(test_input)
assert output != None, "model quantization failed"
@pytest.mark.parametrize("bit_size", bit_sizes)

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
def test_uintx_affine_weight_only_quant(bit_size, group_size, device):
input_float = torch.randn((1,256), dtype=torch.float16, device = device)
def test_uintx_weight_only_quant(bit_width, group_size, device):
input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
mapping_type = MappingType.SYMMETRIC
quant_min = 0
quant_max = 2**bit_size - 1
quant_max = 2 ** bit_width - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
target_dtype = torch.uint8
block_size = (1, group_size)

scale, zero_point = choose_qparams_affine(
input_float, mapping_type, block_size,
target_dtype, quant_min, quant_max, eps, torch.float32,
zero_point_dtype, True, zero_point_domain
input_float, mapping_type, block_size,
target_dtype, quant_min, quant_max, eps, torch.float32,
zero_point_dtype, True, zero_point_domain
)

aqt = quantize_affine(
input_float, block_size, scale,
zero_point, target_dtype,
quant_min = quant_min,
quant_max = quant_max,
zero_point_domain = zero_point_domain
)
q = to_uintx(aqt, bit_size, -1)
)

q = to_uintx(aqt, bit_width, -1)
assert q != None, "quantization failed"
deqaunt = dequantize_affine(
q, block_size, scale,
104 changes: 34 additions & 70 deletions torchao/prototype/uintx/Uintx.py → torchao/dtypes/uintx/Uintx.py
@@ -27,7 +27,7 @@ class UintxTensor(torch.Tensor):
int4_shard (torch.Tensor): 4 bit packed shard
int2_shard (torch.Tensor): 2 bit packed shard
int1_shard (torch.Tensor): 1 bit packed shard
bit_size (int): element size in bits
bit_width (int): number of bits for each element
pack_dim: (int) dimension to pack along
"""
bits_to_shard = {
@@ -43,71 +43,71 @@ def __new__(
cls,
shards: List[torch.Tensor],
packed_shape: List[int],
bit_size: int,
bit_width: int,
pack_dim: int = -1,
):
kwargs = {"device": shards[0].device}
kwargs["device"] = shards[0].device
kwargs["layout"] = shards[0].layout
kwargs["requires_grad"] = False
kwargs["dtype"] = torch.uint8
return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)
return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)

def __init__(
self,
shards: List[torch.Tensor],
packed_shape: List[int],
bit_size: int,
bit_width: int,
pack_dim: int = -1,
):
for i, attrib in enumerate(self.bits_to_shard[bit_size]):
for i, attrib in enumerate(self.bits_to_shard[bit_width]):
setattr(self, attrib, shards[i])

self.packed_shape = packed_shape
self.bit_size = bit_size
self.bit_width = bit_width
self.pack_dim = pack_dim

def get_shards(self):
return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_size]]
return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]]

def __repr__(self):
return f"Int{self.bit_size}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_size, dim = self.pack_dim)})"
return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})"

def __tensor_flatten__(self):
return self.__class__.bits_to_shard[self.bit_size], [self.packed_shape, self.bit_size, self.pack_dim]
return self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
shards = list(tensor_data_dict.values())
packed_shape, bit_size, pack_dim = tensor_attributes
return cls(shards, packed_shape, bit_size, pack_dim)
packed_shape, bit_width, pack_dim = tensor_attributes
return cls(shards, packed_shape, bit_width, pack_dim)

implements = classmethod(_implements)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

def get_plain(self):
return unpack(self.get_shards(), self.bit_size, dim = self.pack_dim)
return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_transformation(self, fn):
og = self.get_plain()
new = fn(og)
return self.from_uint8(new, self.bit_size, self.pack_dim)
return self.from_uint8(new, self.bit_width, self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_fn_to_shards(self, fn):
new_shards = [fn(shard) for shard in self.get_shards()]
return self.__class__(new_shards, self.packed_shape, self.bit_size, self.pack_dim)
return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)

@classmethod
def from_uint8(cls, int_data: torch.Tensor, bit_size, pack_dim: int = -1):
shards = pack(int_data, bit_size, dim=pack_dim)
def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
shards = pack(int_data, bit_width, dim=pack_dim)
shape = list(int_data.shape)
shape[pack_dim] = shape[pack_dim] * bit_size // 8
return cls(shards, int_data.shape, bit_size, pack_dim)
shape[pack_dim] = shape[pack_dim] * bit_width // 8
return cls(shards, int_data.shape, bit_width, pack_dim)


implements = UintxTensor.implements
@@ -118,19 +118,19 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_fn_to_shards(torch.detach)
)

@implements(aten.view.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:]))
)

@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]
)

@implements(aten.sub.Tensor)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
@@ -147,18 +147,18 @@ def _(func, types, args, kwargs):

@dataclass(frozen=True)
class UintxLayoutType(LayoutType):
bit_size: int
bit_width: int
pack_dim: int = -1

def post_process(self, input: torch.Tensor) -> torch.Tensor:
return to_uintx(input, self.bit_size, self.pack_dim)
return to_uintx(input, self.bit_width, self.pack_dim)

@register_layout_cls(UintxLayoutType)
class UintxAQTLayout(PlainAQTLayout):

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.int_data.get_plain(), self.scale, self.zero_point

@classmethod
def from_plain(
cls,
@@ -169,39 +169,3 @@ def from_plain(
):
assert isinstance(layout_type, UintxLayoutType)
return cls(int_data, scale, zero_point, layout_type)


def uintx_affine_weight_only(bit_size, group_size=64, pack_dim=-1):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by the `nbits` argument
"""
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)
from torchao.dtypes import to_affine_quantized
from torchao.quantization.quant_api import _get_linear_subclass_inserter
def apply_uintx_weight_only_quant(weight):

layout_type = UintxLayoutType(bit_size=bit_size, pack_dim=pack_dim)
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
quant_min = 0
quant_max = 2**bit_size - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT

return to_affine_quantized(
weight, mapping_type, block_size, torch.uint8,
quant_min = quant_min, quant_max = quant_max,
eps = eps, zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
layout_type=layout_type,
)

return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
Empty file.