Intx Quantization Tensor Class (#468)
* init class

* tensor subclasses work but slow?

* fixed frame break

* removed a print

* llama profile added

* perf

* added profile time

* added intx quantization to benchmark scripts

* add tests

* Delete trace.json

* Delete profile.txt

* separated dtype and affine quant WIP

* works without compile

* separated stuff, added tests

* remove intx from api until ready

* undo spacing in aqt

* updated torch_dispatch

* updated test

* re-added missing comment

* remove new line

* add new line

* white space fix

* whitespace fix

* fixed test

* refactored implements, actually fixed tests

* tests only run on nightly

* clean up from pr reviews
vayuda authored Aug 7, 2024
1 parent d582f9a commit 87869f2
Showing 8 changed files with 680 additions and 339 deletions.
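For orientation, here is a minimal sketch of how the uintx weight-only quantization path added in this PR is driven from user code. It mirrors the calls exercised in benchmarks/benchmark_uintx.py below (quantize_ and uintx_affine_weight_only are the entry points that file imports); the toy model and the 4-bit setting are illustrative only, not part of the PR:

# Sketch only: mirrors the usage in benchmarks/benchmark_uintx.py below.
import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.uintx import uintx_affine_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=True, dtype=torch.float16)
).cuda()

quantize_(model, uintx_affine_weight_only(4))  # swap Linear weights for 4-bit uintx affine-quantized weights
model = torch.compile(model, fullgraph=True)   # compile, as the benchmark does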
93 changes: 0 additions & 93 deletions benchmarks/benchmark_bitpacking.py

This file was deleted.

109 changes: 109 additions & 0 deletions benchmarks/benchmark_uintx.py
@@ -0,0 +1,109 @@
from math import log
from copy import deepcopy

import torch
from torchao.utils import unwrap_tensor_subclass
from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack, pack_cpu, unpack_cpu
from torchao.quantization.quant_api import quantize_

class Linear16(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale*2, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale//2, bias=True, dtype=torch.float16).cuda(),
        )

    def forward(self, x):
        return self.net(x)


def benchmark(function, args, num_runs):
    # warmup
    torch._dynamo.reset()
    for i in range(100):
        function(*args)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    for _ in range(num_runs):
        function(*args)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


def profile_bitpack():
    from torch.profiler import profile, record_function, ProfilerActivity
    fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()]
    func = torch.compile(unpack_cpu, fullgraph=True)
    with profile(activities=[
            ProfilerActivity.CPU,
            ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True
    ) as prof:

        for _ in range(1000):
            unpacked = func(fake_tensor, 4)

    # Print a summary
    with open("profile-bitpack.txt", "a") as f:
        print(f'{func}', file=f)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f)
    prof.export_chrome_trace("trace.json")
    '''
    CPU perf:
    unpack_gpu
    Self CPU time total: 602.501ms
    unpack_cpu
    Self CPU time total: 415.469ms
    GPU perf:
    unpack_gpu on gpu:
    Self CPU time total: 58.512ms
    Self CUDA time total: 5.083ms
    unpack_cpu:
    Self CPU time total: 96.947ms
    Self CUDA time total: 5.253ms
    '''

def uintx_vs_fp16(nbits=[1,2,3,4,5,6,7], scales=[256, 512, 1024], repeats=30):
    results = []
    nbits.sort()
    scales.sort()
    for scale in scales:
        test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
        forward_args = [test_input]
        times = [scale]

        fp16 = Linear16(scale)
        fp16c = torch.compile(fp16, fullgraph=True)
        fp16_time = benchmark(fp16c.forward, forward_args, repeats)
        times.append(fp16_time)
        for bit_size in nbits:
            m = deepcopy(fp16)
            quantize_(m, uintx_affine_weight_only(bit_size))
            m = torch.compile(m, fullgraph=True)
            uintx_time = benchmark(m.forward, forward_args, repeats)
            times.append(uintx_time)
        print(f'scale={scale} done')

        results.append(times)
    print("----------- benchmark results -----------")
    for result in results:
        print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms speedups:")
        for i in range(2, len(result)):
            print(f"int{nbits[i-2]}: {result[1]/result[i]: .2f}x")


if __name__ == "__main__":
    uintx_vs_fp16(nbits=[4,7])
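Usage note: on a CUDA machine this script can be run directly; with the __main__ arguments above it compares the compiled fp16 baseline against 4-bit and 7-bit uintx weight-only quantization, and the printed speedup is fp16_time / uintx_time, so values above 1.0 mean the quantized model is faster than the fp16 baseline.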


167 changes: 45 additions & 122 deletions test/prototype/test_bitpacking.py
@@ -1,143 +1,66 @@
import torch
-from torchao.prototype.common.bitpacking import pack, unpack
+from torchao.prototype.uintx import pack, unpack, pack_cpu, unpack_cpu
import pytest
from torch.utils._triton import has_triton
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

-dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
-dimensions = (2, 1, 0, -1)
-orders = (True, False)
+element_bit_width = (1,2,3,4,5,6,7)
+dimensions = (0, -1, 1)

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501

    # setup (currently do nothing)

    # tests will run here
    yield
-   torch._dynamo.reset() # reset cache between tests

    # teardown
+   # avoid dynamo cache limit issues
+   torch._dynamo.reset()

-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_CPU(dtype, dim, order):
-   element_bit_width, element_type, expected_pack_size = dtype
-   shape = [4, 4, 4]
-   if element_type == "trinary":
-       test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu')
-   else:
-       test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu')
-
-   packed = pack(test_tensor,
-                 element_bit_width,
-                 element_type=element_type,
-                 dim = dim,
-                 order = order,
-                 container_dtype = torch.uint8)
-   assert(packed.shape[dim] == expected_pack_size)
-   unpacked = unpack(packed,
-                     element_bit_width,
-                     element_type=element_type,
-                     dim = dim,
-                     order = order)
+def test_CPU(element_bit_width, dim):
+   test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
+   packed = pack_cpu(test_tensor, element_bit_width, dim = dim)
+   unpacked = unpack_cpu(packed, element_bit_width, dim = dim)
    assert(unpacked.allclose(test_tensor))

-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("dtype", dtypes)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_GPU(dtype, dim, order):
-   element_bit_width, element_type, expected_pack_size = dtype
-   shape = [4, 4, 4]
-   if element_type == "trinary":
-       test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-   else:
-       test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-   packed = pack(test_tensor,
-                 element_bit_width,
-                 element_type=element_type,
-                 dim = dim,
-                 order = order,
-                 container_dtype = torch.uint8)
-   assert(packed.shape[dim] == expected_pack_size)
-   unpacked = unpack(packed,
-                     element_bit_width,
-                     element_type=element_type,
-                     order = order,
-                     dim = dim)
+def test_GPU(element_bit_width, dim):
+   test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+   packed = pack(test_tensor, element_bit_width, dim = dim)
+   unpacked = unpack(packed, element_bit_width, dim = dim)
    assert(unpacked.allclose(test_tensor))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_padding(dtype, dim, order):
-   element_bit_width, element_type, expected_pack_size = dtype
-   torch._dynamo.config.specialize_int = True
-   shape = [4, 4, 4]
-   shape[dim] = 5
-
-   if element_type == "trinary":
-       test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-   else:
-       test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-   packed = pack(test_tensor,
-                 element_bit_width,
-                 element_type=element_type,
-                 dim = dim,
-                 container_dtype = torch.uint8,
-                 order = order,
-                 pad = True)
-   assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}"  # +1 for this scenario
-   unpacked = unpack(packed,
-                     element_bit_width,
-                     element_type=element_type,
-                     dim = dim,
-                     order = order)
-   slices = [slice(None)] * packed.ndim
-   slices[dim] = slice(None, 5)
-   assert unpacked[slices].allclose(test_tensor)
-
-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_compile(dtype, dim, order):
-   pack_compile = torch.compile(pack, fullgraph=True, dynamic=True)
-   unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True)
-   element_bit_width, element_type, expected_pack_size = dtype
+def test_compile(element_bit_width, dim):
    torch._dynamo.config.specialize_int = True
-   shape = [4, 4, 4]
-   if element_type == "trinary":
-       test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-   else:
-       test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda()
-
-   packed = pack_compile(test_tensor, element_bit_width,
-                         element_type=element_type,
-                         dim = dim,
-                         container_dtype = torch.int8,
-                         order = order)
-   assert(packed.shape[dim] == expected_pack_size)
-   unpacked = unpack_compile(packed,
-                             element_bit_width,
-                             element_type=element_type,
-                             dim = dim,
-                             order = order)
+   pack_compile = torch.compile(pack, fullgraph=True)
+   unpack_compile = torch.compile(unpack, fullgraph=True)
+   test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+   packed = pack(test_tensor, element_bit_width, dim = dim)
+   unpacked = unpack(packed, element_bit_width, dim = dim)
    assert(unpacked.allclose(test_tensor))

+# these test cases are for the example pack walk through in the bitpacking.py file
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_pack_example():
+   test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda()
+   shard_4, shard_2 = pack(test_tensor, 6)
+   print(shard_4, shard_2)
+   assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
+   assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+   unpacked = unpack([shard_4, shard_2], 6)
+   assert unpacked.allclose(test_tensor)
+
+def test_pack_example_CPU():
+   test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8)
+   shard_4, shard_2 = pack(test_tensor, 6)
+   print(shard_4, shard_2)
+   assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
+   assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
+   unpacked = unpack([shard_4, shard_2], 6)
+   assert unpacked.allclose(test_tensor)
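The expected constants in test_pack_example can be reproduced by hand. Below is a small sketch, assuming pack() splits each 6-bit value into a 4-bit shard holding the low bits and a 2-bit shard holding the high bits, with elements packed at a stride of numel // elements_per_byte; the library's internal layout may differ, this only checks the numbers asserted in the test:

# Assumes the low-bits/high-bits shard layout described above; not torchao's
# implementation, just a reproduction of the constants asserted in the test.
import torch

x = torch.tensor([0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8)

low4 = x & 0xF   # bottom 4 bits of each 6-bit value
high2 = x >> 4   # top 2 bits of each 6-bit value

# 4-bit shard: two elements per byte, stride numel // 2 == 4
shard_4 = low4[:4] | (low4[4:] << 4)

# 2-bit shard: four elements per byte, stride numel // 4 == 2
shard_2 = high2[:2] | (high2[2:4] << 2) | (high2[4:6] << 4) | (high2[6:] << 6)

print(shard_4)  # tensor([  0, 105, 151,  37], dtype=torch.uint8)
print(shard_2)  # tensor([ 39, 146], dtype=torch.uint8)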


