Add sparse marlin AQT layout (#621)
* feat: starting layout implementation

fix: namespace of common modules

chore: remove unneeded test file

fix: op name being registered

chore: can compile the cuda kernel

fix: segmentation fault

chore: wip - paste test code just to check if everything passes

feat: wip - adding layout. unpack not working

fix: circular import

feat: wip - can almost revert

feat: can unpack. just needs cleanup

chore: improve layout code

chore: wip - mm needs work

feat: wip - something seems wrong

fix: e2e test

feat: wip - add group param

fix: unpack weights

feat: marlin is implemented and correct

chore: rebase

chore: remove old import

feat: use int4 instead of dequantizing

chore: remove unused fn

feat: add checks and validation

feat: add new kernel and refactor code (#1)

* feat: wip - adding new kernel

* feat: wip - continue working on the unpack

* feat: wip - working on unpacking

* feat: remove old op

* feat: more code changes

* chore: remove old code

* feat: more code

* chore: more code changes

* chore: more code changes

* feat: add more documentation

* fix: dataclass

* feat: add more docs

* feat: remove assert

chore: block 8 bits

chore: update comment

feat: refactor dispatch

chore: add validation on group size

chore: wip - working on fixing unpack

feat: add small readme with sources

feat: add checks

feat: tests pass & can execute llama2

* compile kind of working

* fix: batching and layout outputs correct results

* fix: torch.compile

* wip

* feat: wip

* chore: cleanup

* chore: review

* chore: review v2

* update benchmarks + README

---------

Co-authored-by: Jesse Cai <[email protected]>
2 people authored and andrewor14 committed Sep 6, 2024
1 parent e2f5702 commit 28c3f28
Showing 20 changed files with 538 additions and 102 deletions.
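
For orientation, here is a minimal usage sketch of the layout added in this commit, following the new tests in test/sparsity/test_marlin.py; the model shape and input below are illustrative placeholders, not values from the commit.

import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Illustrative float16 CUDA model; the sparse Marlin kernel targets fp16 linear layers.
model = nn.Sequential(nn.Linear(4096, 11008), nn.Linear(11008, 4096)).half().cuda()

# Prune weights to a 2:4 pattern first so the Marlin layout can pack them.
apply_fake_sparsity(model)

# int4 weight-only quantization using the new sparse Marlin layout.
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
out = model(x)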
115 changes: 115 additions & 0 deletions test/sparsity/test_marlin.py
@@ -0,0 +1,115 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.marlin import (
pack_to_marlin_24,
unpack_from_marlin_24,
inject_24
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
ZeroPointDomain,
MappingType,
)


class SparseMarlin24(TestCase):

def setUp(self):
super().setUp()
torch.manual_seed(0)

self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
self.model = (
nn.Sequential(
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
nn.ReLU(),
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
)
.half()
.cuda()
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_eager(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
self.model.forward = torch.compile(self.model.forward, fullgraph=True)
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_pack_unpack_equivalence(self):
num_bits = 4
group_size = 128
shape = (11008, 4096)
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
zero_point_dtype = torch.bfloat16
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
scale_dtype = None

w = torch.rand(shape, dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity mask
w_24, _ = inject_24(w, *w.shape)

# Quantize weights
scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
w_q_24, scales, num_bits, group_size
)
unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
q_w_comp, packed_scales, meta, shape, group_size, num_bits
)

assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"


if __name__ == "__main__":
run_tests()
30 changes: 27 additions & 3 deletions test/sparsity/test_sparse_api.py
@@ -11,12 +11,11 @@
int8_dynamic_activation_int8_semi_sparse_weight,
semi_sparse_weight,
)
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
int8_dynamic_activation_int8_weight,
quantize_,
int4_weight_only,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
from torch.testing._internal.common_utils import TestCase
@@ -73,5 +72,30 @@ def test_quant_semi_sparse(self):

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse_marlin(self):
input = torch.rand((256, 256)).half().cuda()
model = (
nn.Sequential(
nn.Linear(256, 1024),
nn.Linear(1024, 256),
)
.half()
.cuda()
)

apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

if __name__ == "__main__":
unittest.main()
24 changes: 15 additions & 9 deletions test/test_ops.py
@@ -304,6 +304,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
)


MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
MNK_FACTORS = [
@@ -318,8 +319,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
@@ -374,15 +375,15 @@ def reshape_w(w):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity
@@ -391,19 +392,24 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
# Symmetric quantize
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

# Reshape input into 2D tensor
input_2d = a_input.view(-1, a_input.shape[-1])
a_input_in, a_input_out = input_2d.shape

# Obtains reference output
output_ref = torch.matmul(a_input, w_24_ref)
output_ref = torch.matmul(input_2d, w_24_ref)
output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],))

# Packs to marlin 2:4
marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
torch.cuda.synchronize()
output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],))

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
1 change: 1 addition & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -38,3 +38,4 @@ kv cache quantization:
20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072
20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization
20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask
20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
4 changes: 3 additions & 1 deletion torchao/_models/llama/benchmarks.sh
@@ -30,13 +30,15 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

# sparse marlin (NOTE: float16)
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
# auto-round w/ quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
# auto-round w/o quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0



export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
3 changes: 3 additions & 0 deletions torchao/_models/llama/generate.py
@@ -225,6 +225,9 @@ def main(
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
if "marlin" in quantization:
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer
2 changes: 2 additions & 0 deletions torchao/_models/sam/benchmark.sh
@@ -8,3 +8,5 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
# int8 dynamic quant attn + int4 wo + sparse marlin lin 1 + 2:4 sparse lin2
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half float16 --device cuda --compress int4_weight_only_sparse
34 changes: 24 additions & 10 deletions torchao/_models/sam/eval_combo.py
@@ -283,6 +283,16 @@ def run(
for block in predictor.model.image_encoder.blocks:
block.attn.use_rel_pos = use_rel_pos

# Helper filter functions
def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin1' in name
def mlp_lin2_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin2' in name
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name

if compress == "int8_dynamic_quant":
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -296,15 +306,6 @@ def mlp_only(mod, name):
apply_fake_sparsity(predictor.model.image_encoder)
sparsify_(predictor.model.image_encoder, semi_sparse_weight())
elif compress == "int8_dynamic_quant_sparse":
def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin1' in name
def mlp_lin2_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin2' in name
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name

# apply sparsify first to set qparams
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)
@@ -320,7 +321,20 @@ def mlp_only(mod, name):
mlp_lin2_only)
if not TORCH_VERSION_AT_LEAST_2_5:
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

elif compress == "int4_weight_only_sparse":
# apply sparsify first to set qparams
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)
from torchao.dtypes import MarlinSparseLayoutType
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only)
quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
sparsify_(predictor.model.image_encoder,
semi_sparse_weight(),
mlp_lin2_only)
if not TORCH_VERSION_AT_LEAST_2_5:
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
else:
assert compress is None, f"Unsupported compress mode {compress}"

1 change: 1 addition & 0 deletions torchao/_models/sam/results.csv
@@ -4,3 +4,4 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
}

} // namespace torchao
} // namespace torchao
2 changes: 1 addition & 1 deletion torchao/csrc/sparse_marlin.cpp
@@ -5,4 +5,4 @@
TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor");
}
}
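
On the Python side this schema is exposed as torchao.ops.marlin_24_gemm. Below is a hedged helper sketch derived from the batched-input handling added in test/test_ops.py above; the wrapper name and argument names are illustrative, and the caller is assumed to have already packed the weight, meta, scales, and workspace with the marlin 2:4 packing utilities.

import torch
import torchao

def marlin_24_linear(x, q_w_comp, meta, scales, workspace, num_bits, size_n, size_k):
    # The kernel expects a 2D (size_m, size_k) fp16 activation, so flatten any
    # leading batch dims first, as done in test/test_ops.py.
    x_2d = x.view(-1, x.shape[-1])
    out = torchao.ops.marlin_24_gemm(
        x_2d, q_w_comp, meta, scales, workspace,
        num_bits, x_2d.shape[0], size_n, size_k,
    )
    # Restore the original leading dims, with size_n as the output feature dim.
    return out.reshape(x.shape[:-1] + (size_n,))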
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -15,6 +15,7 @@
TensorCoreTiledLayoutType,
Float8LayoutType,
Float8AQTLayout,
MarlinSparseLayoutType,
)

__all__ = [
@@ -33,4 +34,5 @@
"TensorCoreTiledLayoutType",
"Float8LayoutType",
"Float8AQTLayout",
"MarlinSparseLayoutType",
]