Change NF4Tensor dtype and add support for linear (pytorch#62)
cpuhrsch authored Mar 22, 2024
1 parent 2e68045 commit 11e8163
Showing 3 changed files with 90 additions and 3 deletions.
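In short, this change lets an NF4-quantized weight be passed directly to torch.nn.functional.linear (and through torch.compile). A minimal usage sketch, based on the torchao.dtypes.to_nf4 helper and the block sizes used in the new smoke tests below:

import torch
import torchao

# Quantize a bfloat16 weight to NF4 (block_size=16, scaler_block_size=2,
# the values used in the new smoke tests).
weight = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
weight_nf4 = torchao.dtypes.to_nf4(weight, 16, 2)

# The NF4 weight can now be used directly as the weight of a linear op;
# per the LinearNF4 change below, it is dequantized to the input's dtype.
inp = torch.randn(2, 32, 32, dtype=torch.bfloat16, device="cuda")
out = torch.nn.functional.linear(inp, weight_nf4)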
26 changes: 26 additions & 0 deletions .github/workflows/regression_test.yml
@@ -27,6 +27,32 @@ jobs:
          pip install torch
      - name: Install package
        run: |
          pip install .
      - name: Run tests
        run: |
          pytest test
  test-nightly:
    runs-on: 4-core-ubuntu-gpu-t4
    steps:
      - uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.9

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
      - name: Install package
        run: |
          pip install .
23 changes: 23 additions & 0 deletions test/dtypes/test_nf4.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao

bnb_available = False

@@ -176,6 +177,28 @@ def test_to_copy(self):
        inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
        torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)

    def test_to_bfloat16(self):
        inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
        inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
        assert type(inpt_tensor_nf4) != torch.Tensor
        assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
        assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16

    def test_smoketest_linear(self):
        a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
        a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
        inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
        out1 = torch.nn.functional.linear(inp, a)
        out2 = torch.nn.functional.linear(inp, a_nf4)

    @unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.")
    def test_smoketest_linear_compile(self):
        a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
        a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
        inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
        out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


if __name__ == "__main__":
    unittest.main()
44 changes: 41 additions & 3 deletions torchao/dtypes/nf4tensor.py
@@ -51,14 +51,50 @@ def noop_detach(func, *args, **kwargs):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _to_copy(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return func(args[0][0].t()).t()
    return args[0][0].get_original_weight().to(args[1]['dtype'])

@implements([torch.ops.aten.to.dtype])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def to_dtype(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
    return args[0][0].get_original_weight().to(args[0][1])

@implements([torch.ops.aten.t.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def t_default(func, *args, **kwargs):
    a = args[0][0]
    tensor_meta = SubclassTensorArgs(
        a.size(),
        (a.stride(1), a.stride(0)),
        a.storage_offset(),
        torch.bits2x4,
        a.device,
        a.requires_grad)
    b = NF4Tensor(
        tensor_meta,
        a.block_size,
        a.n_blocks,
        a.scaler_block_size,
        a.quantized_scalers,
        a.quantization_factor,
        a.scaler_mean,
        a.quantized_data,
        a.nf4)
    return b

@implements([torch.ops.aten.mm.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def mm_default(func, *args, **kwargs):
    return linear_nf4(args[0][0], args[0][1])


@implements(
    [
@@ -160,7 +196,8 @@ def __new__(
            tensor_meta.original_shape,
            tensor_meta.original_strides,
            tensor_meta.storage_offset,
            dtype=tensor_meta.dtype,
            # Picked some floating dtype, but we need dtype extensibility
            dtype=torch.float8_e5m2fnuz,
            device=tensor_meta.device,
            requires_grad=tensor_meta.requires_grad,
        )
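This dtype swap is what the commit title refers to: the wrapper subclass no longer advertises the original bfloat16 dtype but a stand-in float8 dtype until, as the new comment notes, proper dtype extensibility exists. A small sketch of the observable effect, assuming .dtype simply reflects the constructor argument shown above:

nf4 = torchao.dtypes.to_nf4(torch.rand(128, dtype=torch.bfloat16), 32, 2)
print(nf4.dtype)                      # expected: torch.float8_e5m2fnuz (placeholder)
print(nf4.to(torch.bfloat16).dtype)   # torch.bfloat16, via the new to.dtype handler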
@@ -198,6 +235,7 @@ def from_tensor(
        block_size: int,
        scaler_block_size: int,
    ):
        assert inpt_tensor.dim() <= 2
        assert inpt_tensor.dtype == torch.bfloat16
        assert (
            inpt_tensor.numel() % block_size == 0
@@ -428,7 +466,7 @@ def quantize_tensor_nearest(
    # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
    # defined in `torch._C.TensorBase`.
    def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
        """Dequantize a nf4 value to float16 format"""
        """Dequantize a nf4 value to bfloat16 format"""
        # return nf4.index_select(0, value)
        return nf4[value]

@@ -546,7 +584,7 @@ class LinearNF4(torch.autograd.Function):
    def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
        """Save the quantized nf4 weight for backward pass"""
        ctx.nf4_weight = weight
        return F.linear(input, weight.get_original_weight())
        return F.linear(input, weight.to(input.dtype))

    @staticmethod
    # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
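For context on the new handlers: with a batched input and no bias, F.linear decomposes (up to reshapes) into a transpose of the weight followed by a matmul, which is presumably why adding t.default and mm.default handlers is enough for F.linear to accept an NF4Tensor weight and route it to linear_nf4. A small illustration with plain tensors (not part of the diff):

import torch

inp = torch.randn(2, 32, 32)
w = torch.randn(32, 32)

# F.linear computes inp @ w.T; for a 3-D input this matches flattening the
# batch dimensions, transposing the weight, and running an ordinary matmul.
out_linear = torch.nn.functional.linear(inp, w)
out_decomposed = torch.mm(inp.reshape(-1, 32), w.t()).reshape(2, 32, 32)
torch.testing.assert_close(out_linear, out_decomposed)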
