Fix TinyGemmQBitsTensor move #246

Merged · 5 commits · Jul 18, 2024
17 changes: 16 additions & 1 deletion optimum/quanto/library/ext/cuda/__init__.py
@@ -85,7 +85,22 @@ def unpack_cuda(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
 
 
-@torch.library.impl("quanto_ext::gemm", ["CUDA"])
+torch.library.define(
+    "quanto::gemm",
+    "(Tensor input,"
+    " Tensor other,"
+    " Tensor other_scale,"
+    " Tensor other_shift,"
+    " int rows,"
+    " int out_cols,"
+    " int in_cols,"
+    " int bits,"
+    " int group_size)"
+    " -> Tensor",
+)
+
+
+@torch.library.impl("quanto::gemm", ["CUDA"])
 def gemm_cuda(
     input: torch.Tensor,
     other: torch.Tensor,
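The hunk above moves the gemm schema next to its CUDA kernel, pairing torch.library.define with torch.library.impl under the quanto namespace instead of the old quanto_ext one. A minimal sketch of that define/impl pairing on a toy op (the demo::scale namespace below is hypothetical, not part of quanto):

import torch

# Declare the op schema under a custom namespace...
torch.library.define("demo::scale", "(Tensor x, float factor) -> Tensor")

# ...then register a backend-specific implementation for it.
@torch.library.impl("demo::scale", "CPU")
def scale_cpu(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Once registered, the op is reachable through the torch.ops namespace:
# torch.ops.demo.scale(torch.ones(3), 2.0)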
13 changes: 0 additions & 13 deletions optimum/quanto/library/ops.py
@@ -68,16 +68,3 @@ def impl(*args, **kwargs):


define("unpack", "(Tensor self, int bits) -> Tensor")
define(
"gemm",
"(Tensor input,"
" Tensor other,"
" Tensor other_scale,"
" Tensor other_shift,"
" int rows,"
" int out_cols,"
" int in_cols,"
" int bits,"
" int group_size)"
" -> Tensor",
)
10 changes: 6 additions & 4 deletions optimum/quanto/tensor/qbits/qbits_ops.py
@@ -64,7 +64,9 @@ def _to_copy(op, t, dtype=None, device=None, **kwargs):
 @register_qbitstensor_op([torch.ops.aten.detach])
 def detach(op, t):
     # Detach is required when copying and deserializing
-    data = op(t._data)
-    scale = op(t._scale)
-    shift = op(t._shift)
-    return t.__class__(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale, shift)
+    inner_tensor_names, meta = t.__tensor_flatten__()
+    # Detach inner tensors
+    detached_tensors = {}
+    for inner_name in inner_tensor_names:
+        detached_tensors[inner_name] = op(getattr(t, inner_name))
+    return t.__class__.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
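The rewritten detach above no longer lists _data, _scale and _shift explicitly; it goes through the tensor-subclass flatten/unflatten protocol, so subclasses with different inner tensors (such as TinyGemmQBitsTensor, which stores a fused _scale_shift) can share the same code path. A hedged sketch of that generic pattern (the helper name is illustrative, not a library function):

def apply_to_inner_tensors(t, op):
    # Flatten the subclass into its named inner tensors, apply the op to each,
    # then rebuild an instance of the same subclass from the original metadata.
    inner_tensor_names, meta = t.__tensor_flatten__()
    new_inner = {name: op(getattr(t, name)) for name in inner_tensor_names}
    return t.__class__.__tensor_unflatten__(new_inner, meta, t.size(), t.stride())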
23 changes: 0 additions & 23 deletions optimum/quanto/tensor/qbits/tinygemm/qbits.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 
 import ast
-from copy import copy
 
 import torch
 from torch.autograd import Function
 
-from ...qtensor import qfallback
 from ...qtype import qtypes
 from ..group import group, ungroup
 from ..qbits import QBitsTensor
@@ -127,24 +125,3 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         size = ast.literal_eval(meta["size"])
         stride = ast.literal_eval(meta["stride"])
         return TinyGemmQBitsTensor(qtype, axis, group_size, size, stride, data, scale_shift)
-
-    @classmethod
-    def __torch_dispatch__(cls, op, types, args, kwargs=None):
-        # Do not use directly op, but rather its overload
-        if op.overloadpacket is torch.ops.aten.detach:
-            t = args[0]
-            data = op(t._data)
-            scale_shift = op(t._scale_shift)
-            return TinyGemmQBitsTensor(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale_shift)
-        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
-            t = args[0]
-            dtype = kwargs.get("dtype", None)
-            if dtype is not None and dtype != t.dtype:
-                raise ValueError("The dtype of a TinyGemmQBitsTensor cannot be changed")
-            scale_shift = op(t._scale_shift, **kwargs)
-            data_kwargs = copy(kwargs)
-            data_kwargs["dtype"] = t._data.dtype
-            data = op(t._data, **data_kwargs)
-            return TinyGemmQBitsTensor(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale_shift)
-        # No dispatch available: qfallback
-        return qfallback(op, *args, **kwargs)
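The deleted _to_copy branch captured a pattern worth keeping in mind when moving or casting quantized tensors: the float scale/shift metadata follows the requested dtype and device, while the packed integer data keeps its own dtype. A standalone sketch of that pattern, with illustrative names (not library API):

from copy import copy

import torch

def move_quantized(data: torch.Tensor, scale_shift: torch.Tensor, **kwargs):
    # Cast/move the float metadata exactly as requested...
    scale_shift = scale_shift.to(**kwargs)
    # ...but force the packed integer data to keep its own dtype.
    data_kwargs = copy(kwargs)
    data_kwargs["dtype"] = data.dtype
    data = data.to(**data_kwargs)
    return data, scale_shift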
8 changes: 5 additions & 3 deletions test/tensor/ops/test_linear_dispatch.py
@@ -53,11 +53,13 @@ def test_linear_gemm_fp16_int4(batch_size, tokens, embeddings, use_bias):
@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("tokens, embeddings", [(256, 256)])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_linear_tinygemm(batch_size, tokens, embeddings, use_bias, device):
def test_linear_bf16_int4(batch_size, tokens, embeddings, use_bias, device):
dtype = torch.bfloat16
weight_qtype = qint4
inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device)
qweight = random_qweight((embeddings, embeddings), weight_qtype, dtype=dtype, axis=0, group_size=128).to(device)
input_shape = (batch_size, tokens, embeddings)
inputs = torch.rand(input_shape, dtype=dtype, device=device)
weight_shape = (embeddings, embeddings)
qweight = random_qweight(weight_shape, weight_qtype, dtype=dtype, axis=0, group_size=128, device=device)
bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None
qout = torch.nn.functional.linear(inputs, qweight, bias)
out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias)
4 changes: 4 additions & 0 deletions test/tensor/qbits/test_tinygemm_packed_tensor.py
@@ -26,6 +26,8 @@
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("random", [True, False])
def test_pack_tinygemm_tensor(in_features, out_features, random, device):
if device.type == "cuda" and torch.cuda.get_device_capability()[0] < 8:
pytest.skip(reason="CUDA device >= sm80 not available")
bits = 4
qmax = 2**bits
shape = (out_features, in_features)
@@ -43,6 +45,8 @@ def test_pack_tinygemm_tensor(in_features, out_features, random, device):

 @pytest.mark.skip_device("mps") # Only available with pytorch 2.4
 def test_move_tinygemm_packed_tensor(device):
+    if device.type == "cuda" and torch.cuda.get_device_capability()[0] < 8:
+        pytest.skip(reason="CUDA device >= sm80 not available")
     shape = (256, 256)
     bits = 4
     qmax = 2**bits
2 changes: 2 additions & 0 deletions test/tensor/qbits/test_tinygemm_qbits_tensor.py
@@ -23,6 +23,8 @@
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
def test_tinygemm_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
if device.type == "cuda" and torch.cuda.get_device_capability()[0] < 8:
pytest.skip(reason="CUDA device >= sm80 not available")
qtype = qint4
group_size = 128
dtype = torch.bfloat16
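All three tests above gate the tinygemm paths on compute capability: the kernels run in bfloat16, which presumably needs a CUDA device with compute capability 8.0 (sm80) or newer. A small helper capturing that guard (illustrative only; the tests inline the check rather than share a helper):

import pytest
import torch

def skip_unless_sm80(device):
    # Same condition the tests above use before exercising tinygemm kernels.
    if device.type == "cuda" and torch.cuda.get_device_capability()[0] < 8:
        pytest.skip(reason="CUDA device >= sm80 not available")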