Intx Quantization Tensor Class #468

Merged · 37 commits · Aug 7, 2024
Changes from 14 commits

Commits
72df99d
init class
vayuda Jun 28, 2024
66150a7
tensor subclasses work but slow?
vayuda Jun 28, 2024
ad2be7e
fixed frame break
vayuda Jun 28, 2024
236dbac
removed a print
vayuda Jun 28, 2024
2ababfe
llama profile added
vayuda Jun 29, 2024
72d32cf
perf
vayuda Jul 1, 2024
c01110a
added profile time
vayuda Jul 1, 2024
692bc94
Merge branch 'pytorch:main' into intx
vayuda Jul 1, 2024
0401468
added intx quantization to benchmark scripts
vayuda Jul 2, 2024
df9bac5
add tests
vayuda Jul 2, 2024
803cf4c
Merge branch 'intx' of https://github.com/vayuda/ao into intx
vayuda Jul 2, 2024
342e325
Merge branch 'pytorch:main' into intx
vayuda Jul 2, 2024
5acc5b4
Delete trace.json
vayuda Jul 2, 2024
94481a3
Delete profile.txt
vayuda Jul 2, 2024
5928394
seperated dtype and affine quant WIP
vayuda Jul 11, 2024
ab171ff
Merge branch 'intx' of https://github.com/vayuda/ao into intx
vayuda Jul 11, 2024
5c76d22
Merge branch 'main' of https://github.com/pytorch/ao into intx
vayuda Jul 11, 2024
e058179
Merge branch 'pytorch:main' into intx
vayuda Jul 17, 2024
98309f0
works without compile
vayuda Jul 17, 2024
b8b5e2a
seperated stuff, added tests
vayuda Aug 2, 2024
c0eecbf
Merge branch 'main' into intx
vayuda Aug 2, 2024
d13952c
remove intx from api til ready
vayuda Aug 2, 2024
be488ec
undo spacing in aqt
vayuda Aug 2, 2024
7f48cf1
Merge branch 'main' into intx
vayuda Aug 5, 2024
eef2013
Merge branch 'pytorch:main' into intx
vayuda Aug 6, 2024
2b3cd2e
updated torch_dispatch
vayuda Aug 6, 2024
4fbaee8
Merge branch 'intx' of https://github.com/vayuda/ao into intx
vayuda Aug 6, 2024
8028039
updated test
vayuda Aug 6, 2024
983d964
re-added missing comment
vayuda Aug 6, 2024
a27e06d
remove new line
vayuda Aug 6, 2024
a86c2a8
add new line
vayuda Aug 6, 2024
de5d6ee
white space fix
vayuda Aug 6, 2024
f9b4a6a
whitespace fix
vayuda Aug 6, 2024
87300f6
fixed test
vayuda Aug 6, 2024
5868576
refactored implements, actually fixed tests
vayuda Aug 6, 2024
398701b
tests only run on nightly
vayuda Aug 6, 2024
1505bca
clean up from pr reviews
vayuda Aug 7, 2024
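
The diffs below replace the old bitpacking benchmark with benchmarks/benchmark_intx.py and slim down the bitpacking tests. As a quick orientation, here is a minimal usage sketch of the weight-only intx quantization path exercised by that benchmark; the import path and the intx_weight_only(bit_size, group_size, layout) signature are taken from the benchmark script in this PR and may not match the final public API:

import torch
from torchao.quantization.quant_api import quantize, intx_weight_only

# A small fp16 model, mirroring the Linear16 module used in the benchmark.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False, dtype=torch.float16),
).cuda()

# Quantize weights to 4-bit integers in groups of 64, using the bit-packed layout
# (the benchmark also tries the "plain" layout).
model = quantize(model, intx_weight_only(4, group_size=64, layout="packed"))
model = torch.compile(model, fullgraph=True)

x = torch.randn(1024, dtype=torch.float16, device="cuda")
out = model(x)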
93 changes: 0 additions & 93 deletions benchmarks/benchmark_bitpacking.py

This file was deleted.

126 changes: 126 additions & 0 deletions benchmarks/benchmark_intx.py
@@ -0,0 +1,126 @@
from math import log
from copy import deepcopy

import torch
from torch.profiler import profile, record_function, ProfilerActivity

from torchao.prototype.intx.bitpacking import pack, unpack, pack_cpu, unpack_cpu
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4
from torchao.quantization.quant_api import quantize, intx_weight_only, int8_weight_only


def benchmark(function, args, num_runs):
# warmup
for i in range(100):
function(*args)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

for _ in range(num_runs):
function(*args)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profile_function(function, args, num_runs):
function(*args)
torch.cuda.synchronize()
with profile(activities=[
ProfilerActivity.CPU,
ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
with record_function("model_inference"):
for _ in range(num_runs):
function(*args)

# Print a summary
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))

def profile_bitpack():

fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()]
func = torch.compile(unpack_cpu, fullgraph=True)
with profile(activities=[
ProfilerActivity.CPU,
ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True
) as prof:

for _ in range(1000):
unpacked = func(fake_tensor, 4)

# Print a summary
with open("profile-bitpack.txt", "a") as f:
print(f'{func}',file=f)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f)
prof.export_chrome_trace("trace.json")
'''
gpu unpack on cpu Self CPU time total: 602.501ms
cpu unpack on cpu Self CPU time total: 415.469ms

gpu unpack on gpu:
Self CPU time total: 58.512ms
Self CUDA time total: 5.083ms

cpu unpack on gpu:
Self CPU time total: 96.947ms
Self CUDA time total: 5.253ms
'''

def intx_vs_fp16(nbits= [1,2,3,4,5,6,7], scales=[256, 512, 1024],layouts=["plain","packed"], repeats=30):
class Linear16(torch.nn.Module):
def __init__(self, scale):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16).cuda(),
torch.nn.Linear(scale, scale, bias=False, dtype=torch.float16).cuda(),
torch.nn.Linear(scale, scale//2, bias=False, dtype=torch.float16).cuda(),
)

def forward(self, x):
return self.net(x)

results = []
nbits.sort()
scales.sort()
for scale in scales:
print("scale: ", scale)
test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
forward_args = [test_input]
times = [scale]

fp16 = Linear16(scale)
fp16c = torch.compile(fp16, fullgraph=True)
fp16_time = benchmark(fp16c.forward, forward_args, repeats)
times.append(fp16_time)
print('fp16 done')
for bit_size in nbits:
for layout in layouts:
intx = deepcopy(fp16)
intx = quantize(intx, intx_weight_only(bit_size, group_size=64, layout=layout))
intx = torch.compile(intx, fullgraph=True)
intx_time = benchmark(intx.forward, forward_args, repeats)
times.append(intx_time)
print(f'int{bit_size}-{layout} done')
torch._dynamo.reset()
results.append(times)
print("----------- benchmark results -----------")
for result in results:
result_str = "\n".join([f"int{nbits[i]}: {result[1]/result[2+len(layouts)*i]:.2f}x\t{result[1]/result[3+len(layouts)*i]:.2f}x\t{result[1]/result[4+i]:.2f}x" for i in range(len(nbits))])
print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms {layouts} speedups:\n{result_str}")



if __name__ == "__main__":
# test_bitpack_iso()
# profile_intx(4)
# profile_intx(6)
# profile_intx()
intx_vs_fp16(nbits=[5,6,7],scales=[4096], layouts = ["plain","packed"], repeats =10000)
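
The test changes that follow exercise a simplified pack/unpack interface that takes only the tensor, the element bit width, and the packing dim. A minimal round-trip sketch under the same assumptions (prototype module path and signature copied from the updated tests, not a finalized API):

import torch
from torchao.prototype.intx.bitpacking import pack, unpack

element_bit_width = 4
t = torch.randint(0, 2 ** element_bit_width, (32, 32, 32), dtype=torch.uint8)

packed = pack(t, element_bit_width, dim=-1)     # fewer container elements along dim
unpacked = unpack(packed, element_bit_width, dim=-1)

assert unpacked.allclose(t)                     # lossless round trip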
147 changes: 24 additions & 123 deletions test/prototype/test_bitpacking.py
@@ -1,143 +1,44 @@
import torch
from torchao.prototype.common.bitpacking import pack, unpack
from torchao.prototype.intx.bitpacking import pack, unpack
import pytest
from torch.utils._triton import has_triton
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
dimensions = (2, 1, 0, -1)
orders = (True, False)

element_bit_width = (1,2,3,4,5,6,7)
dimensions = (0, -1, 1)

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
# source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501

# setup (currently do nothing)

# tests will run here
yield
torch._dynamo.reset() # reset cache between tests

# teardown
# avoid dynamo cache limit issues
torch._dynamo.reset()

@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
@pytest.mark.parametrize("order", orders)
def test_CPU(dtype, dim, order):
element_bit_width, element_type,expected_pack_size = dtype
shape = [4, 4, 4]
if element_type == "trinary":
test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu')
else:
test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu')

packed = pack(test_tensor,
element_bit_width,
element_type=element_type,
dim = dim,
order = order,
container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size)
unpacked = unpack(packed,
element_bit_width,
element_type=element_type,
dim = dim,
order = order)
def test_CPU(element_bit_width, dim):
test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
packed = pack(test_tensor, element_bit_width, dim = dim)
unpacked = unpack(packed, element_bit_width, dim = dim)
assert(unpacked.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
@pytest.mark.parametrize("order", orders)
def test_GPU(dtype, dim, order):
element_bit_width, element_type,expected_pack_size = dtype
shape = [4, 4, 4]
if element_type == "trinary":
test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
else:
test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()

packed = pack(test_tensor,
element_bit_width,
element_type=element_type,
dim = dim,
order = order,
container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size)
unpacked = unpack(packed,
element_bit_width,
element_type=element_type,
order = order,
dim = dim)
def test_GPU(element_bit_width, dim):
test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
packed = pack(test_tensor, element_bit_width, dim = dim)
unpacked = unpack(packed, element_bit_width, dim = dim)
assert(unpacked.allclose(test_tensor))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("dim", dimensions)
@pytest.mark.parametrize("order", orders)
def test_padding(dtype, dim, order):
element_bit_width, element_type,expected_pack_size = dtype
torch._dynamo.config.specialize_int = True
shape =[4, 4, 4]
shape[dim] = 5

if element_type == "trinary":
test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
else:
test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()

packed = pack(test_tensor,
element_bit_width,
element_type=element_type,
dim = dim,
container_dtype = torch.uint8,
order = order,
pad= True)
assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario
unpacked = unpack(packed,
element_bit_width,
element_type=element_type,
dim = dim,
order = order)
slices = [slice(None)] * packed.ndim
slices[dim] = slice(None, 5)
assert unpacked[slices].allclose(test_tensor)



@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
@pytest.mark.parametrize("order", orders)
def test_compile(dtype, dim, order):
pack_compile = torch.compile(pack, fullgraph=True, dynamic=True)
unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True)
element_bit_width, element_type,expected_pack_size = dtype
def test_compile(element_bit_width, dim):
torch._dynamo.config.specialize_int = True
shape = [4, 4, 4]
if element_type == "trinary":
test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
else:
test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda()

packed = pack_compile(test_tensor, element_bit_width,
element_type=element_type,
dim = dim,
container_dtype = torch.int8,
order = order)
assert(packed.shape[dim] == expected_pack_size)
unpacked = unpack_compile(packed,
element_bit_width,
element_type=element_type,
dim = dim,
order = order)
assert(unpacked.allclose(test_tensor))
pack_compile = torch.compile(pack, fullgraph=True)
unpack_compile = torch.compile(unpack, fullgraph=True)
test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
packed = pack(test_tensor, element_bit_width, dim = dim)
unpacked = unpack(packed, element_bit_width, dim = dim)
assert(unpacked.allclose(test_tensor))