Factor out the specific configurations to helper functions #286

Merged · 2 commits · May 28, 2024
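This PR factors the per-test quantization configuration (mapping type, block size, target dtype, eps, quant min/max, and the per-token input-quantization lambdas) out of test/quantization/test_quant_api.py into four reusable factory helpers in torchao/quantization/quant_api.py: get_apply_8da4w_quant, get_apply_int4wo_quant, get_apply_int8wo_quant, and get_apply_int8dyn_quant. Each test then reduces to a single quantize(m, get_apply_*(...)) call; a usage sketch follows the new helpers at the end of the diff.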
100 changes: 9 additions & 91 deletions test/quantization/test_quant_api.py
@@ -37,6 +37,10 @@
Quantizer,
TwoStepQuantizer,
quantize,
get_apply_8da4w_quant,
get_apply_int4wo_quant,
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
@@ -416,42 +420,11 @@ def test_eval_wrapper(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
# weight settings
groupsize = 32
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()

def apply_weight_quant(weight):
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)

def apply_act_quant(weight):
return to_laq(weight, input_quant_func)

# note: order is important
m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -474,27 +447,13 @@ def apply_act_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def apply_weight_quant(weight):
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

m = quantize(m, apply_weight_quant)
groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -511,21 +470,11 @@ def apply_weight_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def apply_weight_quant(weight):
block_size = (1, weight.shape[1])
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

m = quantize(m, apply_weight_quant)
m = quantize(m, get_apply_int8wo_quant())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -543,43 +492,12 @@ def apply_weight_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))

def apply_weight_quant(weight):
block_size = get_weight_block_size(weight)
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

def apply_act_quant(weight):
return to_laq(weight, input_quant_func)

m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
m = quantize(m, get_apply_int8dyn_quant())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
110 changes: 110 additions & 0 deletions torchao/quantization/quant_api.py
@@ -32,6 +32,12 @@
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
to_laq,
)

from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
@@ -56,6 +62,10 @@
"quantize",
"autoquant",
"_get_subclass_inserter",
"get_apply_8da4w_quant",
"get_apply_int4wo_quant",
"get_apply_int8wo_quant",
"get_apply_int8dyn_quant",
]

if TORCH_VERSION_AFTER_2_3:
@@ -287,3 +297,103 @@ def filter_fn(module, fqn):
_is_linear if filter_fn is None else filter_fn,
)
return model

def get_apply_8da4w_quant(groupsize=32):
Review comment (Contributor): should we make these private for now? I'm not sure if we want users to call these directly

Reply (Contributor Author): it's fine I think, I'm planning to use this in an example as well. what's the concern of this being public?

def apply_8da4w_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
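            # builds [1, ..., 1, x.shape[-1]]: every leading dim gets group size 1, so each token is its own quantization group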
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_laq(weight, input_quant_func)
return weight

return apply_8da4w_quant
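Each of these helpers follows the same factory pattern: the outer get_apply_* function captures the configuration (e.g. groupsize) in a closure and returns an apply function with the single-tensor signature that quantize expects. A minimal sketch of the pattern, with hypothetical names and a simplified to_aq call:

def get_apply_my_quant(groupsize=32):
    def apply_my_quant(weight):
        # groupsize is captured from the enclosing scope, so the returned
        # function can be handed to quantize() with no extra arguments
        block_size = (1, groupsize)
        return to_aq(weight, MappingType.SYMMETRIC, block_size, torch.int8)
    return apply_my_quant

m = quantize(m, get_apply_my_quant(groupsize=64))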


def get_apply_int4wo_quant(groupsize=32):
def apply_int4wo_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

return apply_int4wo_quant
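The defaults here (torch.int32 target, asymmetric mapping, a bfloat16 zero point in ZeroPointDomain.FLOAT, preserve_zero=False) are presumably chosen to match the packed int4 "tinygemm" weight-only kernel, which would explain why the corresponding test runs the model in bfloat16 on CUDA.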


def get_apply_int8wo_quant():
def apply_int8wo_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = (1, weight.shape[1])
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
return apply_int8wo_quant
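Note that block_size = (1, weight.shape[1]) spans one full row of the weight, i.e. one scale per output channel (per-channel symmetric int8 weight quantization).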

def get_apply_int8dyn_quant():
def apply_int8dyn_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
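            # e.g. x.shape == (20, 1024) -> block_size == [1, 1024], i.e. per-token granularity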
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
weight = to_laq(weight, input_quant_func)
return weight
return apply_int8dyn_quant
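
Taken together, a minimal usage sketch of the new helpers (assuming a float nn.Module m such as the tests' ToyLinearModel; each quantize line is an alternative, not a pipeline, and the int4 path assumes bfloat16 weights on CUDA as in the tests):

from torchao.quantization.quant_api import (
    quantize,
    get_apply_8da4w_quant,
    get_apply_int4wo_quant,
    get_apply_int8wo_quant,
    get_apply_int8dyn_quant,
)

# choose exactly one of the following
m = quantize(m, get_apply_8da4w_quant(groupsize=32))     # int8 dynamic activations + int4 grouped weights
# m = quantize(m, get_apply_int4wo_quant(groupsize=32))  # int4 weight-only (bfloat16, CUDA)
# m = quantize(m, get_apply_int8wo_quant())              # int8 weight-only, per channel
# m = quantize(m, get_apply_int8dyn_quant())             # int8 dynamic activations + int8 weights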