Intx Quantization Tensor Class (#468)

* init class
* tensor subclasses work but slow?
* fixed frame break
* removed a print
* llama profile added
* perf
* added profile time
* added intx quantization to benchmark scripts
* add tests
* Delete trace.json
* Delete profile.txt
* separated dtype and affine quant WIP
* works without compile
* separated stuff, added tests
* remove intx from api til ready
* undo spacing in aqt
* updated torch_dispatch
* updated test
* re-added missing comment
* remove new line
* add new line
* white space fix
* whitespace fix
* fixed test
* refactored implements, actually fixed tests
* tests only run on nightly
* clean up from pr reviews
Showing 8 changed files with 680 additions and 339 deletions.
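The benchmark script below drives the new uintx weight-only quantization through quantize_. As a quick orientation, here is a minimal usage sketch based only on how that script calls the API; the torchao.prototype.uintx import path and uintx_affine_weight_only are prototype-level names taken from this commit, so they may change or be unavailable in other releases, and a CUDA device with float16 weights is assumed.

# Minimal usage sketch of the uintx weight-only quantization exercised by the
# benchmark script in this commit. The prototype import path and the
# uintx_affine_weight_only API are assumptions taken from that script.
import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.uintx import uintx_affine_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=True, dtype=torch.float16)
).cuda()

# Swap each Linear weight for a 4-bit (uint4) affine, weight-only representation.
quantize_(model, uintx_affine_weight_only(4))
model = torch.compile(model, fullgraph=True)

x = torch.randn(1024, dtype=torch.float16).cuda()
out = model(x)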
@@ -0,0 +1,109 @@
from math import log
from copy import deepcopy

import torch
from torchao.utils import unwrap_tensor_subclass
from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack, pack_cpu, unpack_cpu
from torchao.quantization.quant_api import quantize_


class Linear16(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale*2, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale//2, bias=True, dtype=torch.float16).cuda(),
        )

    def forward(self, x):
        return self.net(x)


def benchmark(function, args, num_runs):
    # warmup
    torch._dynamo.reset()
    for i in range(100):
        function(*args)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    for _ in range(num_runs):
        function(*args)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


def profile_bitpack():
    from torch.profiler import profile, record_function, ProfilerActivity
    fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()]
    func = torch.compile(unpack_cpu, fullgraph=True)
    with profile(activities=[
                ProfilerActivity.CPU,
                ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True
            ) as prof:

        for _ in range(1000):
            unpacked = func(fake_tensor, 4)

    # Print a summary
    with open("profile-bitpack.txt", "a") as f:
        print(f'{func}', file=f)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f)
    prof.export_chrome_trace("trace.json")

'''
CPU perf:
    unpack_gpu
    Self CPU time total: 602.501ms
    unpack_cpu
    Self CPU time total: 415.469ms
GPU perf:
    unpack_gpu on gpu:
    Self CPU time total: 58.512ms
    Self CUDA time total: 5.083ms
    unpack_cpu:
    Self CPU time total: 96.947ms
    Self CUDA time total: 5.253ms
'''


def uintx_vs_fp16(nbits=[1,2,3,4,5,6,7], scales=[256, 512, 1024], repeats=30):
    results = []
    nbits.sort()
    scales.sort()
    for scale in scales:
        test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
        forward_args = [test_input]
        times = [scale]

        fp16 = Linear16(scale)
        fp16c = torch.compile(fp16, fullgraph=True)
        fp16_time = benchmark(fp16c.forward, forward_args, repeats)
        times.append(fp16_time)
        for bit_size in nbits:
            m = deepcopy(fp16)
            quantize_(m, uintx_affine_weight_only(bit_size))
            m = torch.compile(m, fullgraph=True)
            uintx_time = benchmark(m.forward, forward_args, repeats)
            times.append(uintx_time)
        print(f'scale={scale} done')

        results.append(times)
    print("----------- benchmark results -----------")
    for result in results:
        print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms speedups:")
        for i in range(2, len(result)):
            print(f"int{nbits[i-2]}: {result[1]/result[i]: .2f}x")


if __name__ == "__main__":
    uintx_vs_fp16(nbits=[4,7])
@@ -1,143 +1,66 @@
 import torch
-from torchao.prototype.common.bitpacking import pack, unpack
+from torchao.prototype.uintx import pack, unpack, pack_cpu, unpack_cpu
 import pytest
 from torch.utils._triton import has_triton
 from torchao.utils import TORCH_VERSION_AFTER_2_4

 if not TORCH_VERSION_AFTER_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
-dimensions = (2, 1, 0, -1)
-orders = (True, False)
+element_bit_width = (1,2,3,4,5,6,7)
+dimensions = (0, -1, 1)

 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
-    # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501
-
-    # setup (currently do nothing)
-
-    # tests will run here
     yield
-
-    # teardown
-    # avoid dynamo cache limit issues
-    torch._dynamo.reset()
+    torch._dynamo.reset() # reset cache between tests

-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_CPU(dtype, dim, order):
-    element_bit_width, element_type,expected_pack_size = dtype
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu')
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu')
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  order = order,
-                  container_dtype = torch.uint8)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      dim = dim,
-                      order = order)
+def test_CPU(element_bit_width, dim):
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
+    packed = pack_cpu(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack_cpu(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))

-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("dtype", dtypes)

+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_GPU(dtype, dim, order):
-    element_bit_width, element_type,expected_pack_size = dtype
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  order = order,
-                  container_dtype = torch.uint8)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      order = order,
-                      dim = dim)
+def test_GPU(element_bit_width, dim):
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_padding(dtype, dim, order):
-    element_bit_width, element_type,expected_pack_size = dtype
-    torch._dynamo.config.specialize_int = True
-    shape =[4, 4, 4]
-    shape[dim] = 5
-
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  container_dtype = torch.uint8,
-                  order = order,
-                  pad= True)
-    assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      dim = dim,
-                      order = order)
-    slices = [slice(None)] * packed.ndim
-    slices[dim] = slice(None, 5)
-    assert unpacked[slices].allclose(test_tensor)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_compile(dtype, dim, order):
-    pack_compile = torch.compile(pack, fullgraph=True, dynamic=True)
-    unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True)
-    element_bit_width, element_type,expected_pack_size = dtype
+def test_compile(element_bit_width, dim):
     torch._dynamo.config.specialize_int = True
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda()
-
-    packed = pack_compile(test_tensor, element_bit_width,
-                          element_type=element_type,
-                          dim = dim,
-                          container_dtype = torch.int8,
-                          order = order)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack_compile(packed,
-                              element_bit_width,
-                              element_type=element_type,
-                              dim = dim,
-                              order = order)
+    pack_compile = torch.compile(pack, fullgraph=True)
+    unpack_compile = torch.compile(unpack, fullgraph=True)
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))

+# these test cases are for the example pack walk through in the bitpacking.py file
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_pack_example():
+    test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda()
+    shard_4,shard_2 = pack(test_tensor, 6)
+    print(shard_4, shard_2)
+    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
+    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    unpacked = unpack([shard_4, shard_2], 6)
+    assert unpacked.allclose(test_tensor)

+def test_pack_example_CPU():
+    test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8)
+    shard_4,shard_2 = pack(test_tensor, 6)
+    print(shard_4, shard_2)
+    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
+    assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
+    unpacked = unpack([shard_4, shard_2], 6)
+    assert unpacked.allclose(test_tensor)
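The pack example tests above illustrate the sharded layout used for awkward bit widths: each 6-bit element is split into a 4-bit part and a 2-bit part, and each part is packed densely into its own uint8 shard. The short sketch below reproduces the expected shard values from test_pack_example with plain integer arithmetic; the exact bit positions are inferred from those expected tensors rather than from torchao's bitpacking source, so treat the layout as illustrative.

# Reproduce the expected shards from test_pack_example by hand.
# The bit layout is inferred from the test's expected values
# (not taken from torchao's implementation), so it is illustrative only.
vals = [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22]  # eight 6-bit values

# 4-bit shard: element i occupies the low nibble of byte i,
# element i + 4 occupies the high nibble.
shard_4 = [((vals[i + 4] & 0xF) << 4) | (vals[i] & 0xF) for i in range(4)]

# 2-bit shard: byte j holds the top two bits of elements j, j+2, j+4, j+6,
# packed from the least significant bit pair upward.
shard_2 = [sum(((vals[j + 2 * k] >> 4) & 0x3) << (2 * k) for k in range(4)) for j in range(2)]

assert shard_4 == [0, 105, 151, 37]
assert shard_2 == [39, 146]

# Round trip: recombine the two shards into the original 6-bit values.
restored = [
    (((shard_2[i % 2] >> (2 * (i // 2))) & 0x3) << 4)
    | ((shard_4[i % 4] >> (4 * (i // 4))) & 0xF)
    for i in range(8)
]
assert restored == vals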