Change NF4Tensor dtype and add support for linear #62
Merged: 12 commits, Mar 22, 2024
.github/workflows/regression_test.yml (26 additions, 0 deletions)
@@ -27,6 +27,32 @@ jobs:
pip install torch


- name: Install package
run: |
pip install .

- name: Run tests
run: |
pytest test

test-nightly:
runs-on: 4-core-ubuntu-gpu-t4
steps:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121


- name: Install package
run: |
pip install .
test/dtypes/test_nf4.py (23 additions, 0 deletions)
@@ -8,6 +8,7 @@
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao

bnb_available = False

@@ -176,6 +177,28 @@ def test_to_copy(self):
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)

def test_to_bfloat16(self):
inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
assert type(inpt_tensor_nf4) != torch.Tensor
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
Review comment (Member): Also check the dtype of inpt_tensor_nf4.to(torch.bfloat16)?

assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16

def test_smoketest_linear(self):
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out1 = torch.nn.functional.linear(inp, a)
out2 = torch.nn.functional.linear(inp, a_nf4)

@unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.")
def test_smoketest_linear_compile(self):
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)



if __name__ == "__main__":
unittest.main()
torchao/dtypes/nf4tensor.py (41 additions, 3 deletions)
@@ -51,14 +51,50 @@ def noop_detach(func, *args, **kwargs):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _to_copy(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return func(args[0][0].t()).t()
return args[0][0].get_original_weight().to(args[1]['dtype'])

@implements([torch.ops.aten.to.dtype])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def to_dtype(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
return args[0][0].get_original_weight().to(args[0][1])

@implements([torch.ops.aten.t.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def t_default(func, *args, **kwargs):
a = args[0][0]
tensor_meta = SubclassTensorArgs(
a.size(),
(a.stride(1), a.stride(0)),
a.storage_offset(),
torch.bits2x4,
a.device,
a.requires_grad)
b = NF4Tensor(
tensor_meta,
a.block_size,
a.n_blocks,
a.scaler_block_size,
a.quantized_scalers,
a.quantization_factor,
a.scaler_mean,
a.quantized_data,
a.nf4)
return b

@implements([torch.ops.aten.mm.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def mm_default(func, *args, **kwargs):
return linear_nf4(args[0][0], args[0][1])


@implements(
[
@@ -160,7 +196,8 @@ def __new__(
tensor_meta.original_shape,
tensor_meta.original_strides,
tensor_meta.storage_offset,
- dtype=tensor_meta.dtype,
+ # Picked some floating dtype, but we need dtype extensibility
+ dtype=torch.float8_e5m2fnuz,
device=tensor_meta.device,
requires_grad=tensor_meta.requires_grad,
)
@@ -198,6 +235,7 @@ def from_tensor(
block_size: int,
scaler_block_size: int,
):
assert inpt_tensor.dim() <= 2
assert inpt_tensor.dtype == torch.bfloat16
assert (
inpt_tensor.numel() % block_size == 0
@@ -428,7 +466,7 @@ def quantize_tensor_nearest(
# pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
# defined in `torch._C.TensorBase`.
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
"""Dequantize a nf4 value to float16 format"""
"""Dequantize a nf4 value to bfloat16 format"""
Review comment (Member): Is the nf4 tensor still restricted to bf16 only for the higher precision, or are there any blockers to supporting fp32?

Reply (Contributor, author): We should be able to support arbitrary precision for conversion, but of course the fidelity of nf4 is independent of the dtype that was passed during construction.

# return nf4.index_select(0, value)
return nf4[value]

@@ -546,7 +584,7 @@ class LinearNF4(torch.autograd.Function):
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
- return F.linear(input, weight.get_original_weight())
Review comment (Contributor, author): @drisspg - So we used to dequantize for each linear call? I guess that makes sense since it's essentially weight-only quant.
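A minimal sketch of the weight-only quantization pattern being described, for context (the NF4Linear module and its defaults are hypothetical, not part of this PR): the weight is quantized to NF4 once up front, and every forward call dequantizes it to the activation dtype before a standard F.linear, mirroring the LinearNF4.forward change below.

import torch
import torch.nn.functional as F
from torchao.dtypes import to_nf4

class NF4Linear(torch.nn.Module):
    """Hypothetical weight-only-quantized linear layer: the weight is stored as
    NF4 and dequantized on every forward call."""
    def __init__(self, in_features, out_features, block_size=32, scaler_block_size=2):
        super().__init__()
        weight = torch.randn(out_features, in_features, dtype=torch.bfloat16)
        self.weight = to_nf4(weight, block_size, scaler_block_size)

    def forward(self, x):
        # Dequantize per call: activations keep their dtype, and only the
        # weight lives in NF4 between calls.
        return F.linear(x, self.weight.to(x.dtype))

# Example usage with bf16 activations:
layer = NF4Linear(32, 32)
out = layer(torch.randn(2, 32, 32, dtype=torch.bfloat16))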

+ return F.linear(input, weight.to(input.dtype))

@staticmethod
# pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`