Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change NF4Tensor dtype and add support for linear #62

Merged
merged 12 commits into from
Mar 22, 2024
6 changes: 6 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def test_to_copy(self):
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)

def test_to_bfloat16(self):
inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
assert type(inpt_tensor_nf4) != torch.Tensor
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check the dtype of inpt_tensor_nf4.to(torch.bfloat16)?



if __name__ == "__main__":
unittest.main()
38 changes: 36 additions & 2 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ def _to_copy(func, *args, **kwargs):
def to_dtype(func, *args, **kwargs):
return args[0][0].get_original_weight().to(args[0][1])

@implements([torch.ops.aten.t.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def t_default(func, *args, **kwargs):
a = args[0][0]
tensor_meta = SubclassTensorArgs(
a.size(),
a.stride(),
a.storage_offset(),
torch.bits2x4,
a.device,
a.requires_grad)
b = NF4Tensor(
tensor_meta,
a.block_size,
a.n_blocks,
a.scaler_block_size,
a.quantized_scalers,
a.quantization_factor,
a.scaler_mean,
a.quantized_data,
a.nf4,
not a.transpose)
return b

@implements([torch.ops.aten.mm.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def mm_default(func, *args, **kwargs):
return linear_nf4(args[0][0], args[0][1])


@implements(
[
Expand Down Expand Up @@ -139,6 +170,7 @@ def __new__(
scaler_mean: torch.Tensor,
quantized_data: torch.Tensor,
nf4: torch.Tensor,
transpose=False,
):
"""Create a new NF4Tensor object
Args:
Expand All @@ -160,7 +192,7 @@ def __new__(
tensor_meta.original_shape,
tensor_meta.original_strides,
tensor_meta.storage_offset,
dtype=tensor_meta.dtype,
dtype=torch.bits2x4,
Copy link
Contributor

@drisspg drisspg Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for provenance I still dont like this 🙃

I think that nf4tensor's outer wrapper subclass should have the same dtype as the type that it was created from.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. We need a better extensibility story for dtypes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think we want to deprecate these, why not use torch.uint2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nf4 is a 4bit type. I suppose another mitigation is a type guard at torch_dispatch level and using torch.bits8 just so the allocator will always spit out bytes (not like it has a choice).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.bits2x4 means 8 bit though, these dtypes (including bits1x8, bits4x2) should be removed actually, since torch.bits8 means the same thing because the meaning is uninterpreted dtypes

so what are you trying to express here? 2 bits * 2 that packed into a 4 bit?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this sounds like a uint4 tensor with a different packing format, can you reuse uint4 Tensor as the underlying dtype (by inheriting from UInt4Tensor probably)? can you write down all the use cases for nf4 dtype as well so we get some idea of how we can support it?

bits8 is generally not recommended right now either btw, since all these bit shifting ops etc. are already available in uint8 so we'd recommend uint8 if you want a 8 bit dtype.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for provenance I still dont like this 🙃. I think that nf4tensor's outer wrapper subclass should have the same dtype as the type that it was created from.

I agree. Having this represent the high precision dtype has worked well for Float8Tensor.

Copy link
Contributor

@drisspg drisspg Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah Uint4Tensor is the same as NF4Tensor, AFAIK I think ed copied the packing format from NFTensor in nuggets and that was the basis of uint4tensor.

Nf4Tensor was copied over to ao and not inherited for speed of enabling torchtune. But I agree that NF4 should like inherit from uint4

That being said this same outer tensor dtype issue applies the same for the uint4tensor same as it does this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So Float8Tensor's dtype is bfloat16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. For float8 tensor specifically, this is required, because we need to trick autograd's assert x.dtype == x.grad.dtype restriction. But it's also conceptually simple to reason about, "this is an emulation of a bfloat16 tensor with scaled float8".

device=tensor_meta.device,
requires_grad=tensor_meta.requires_grad,
)
Expand All @@ -178,6 +210,7 @@ def __init__(
scaler_mean: torch.Tensor,
quantized_data: torch.Tensor,
nf4: torch.Tensor,
transpose=False,
):
"""Initialize the NF4Tensor class"""
self.block_size = block_size
Expand All @@ -188,6 +221,7 @@ def __init__(
self.scaler_mean = scaler_mean
self.quantized_data = quantized_data
self.nf4 = nf4
self.transpose = transpose
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, maybe I'll put this into the strides instead and rely on is_contiguous instead.


@classmethod
@torch.no_grad()
Expand Down Expand Up @@ -546,7 +580,7 @@ class LinearNF4(torch.autograd.Function):
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
return F.linear(input, weight.get_original_weight())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@drisspg - So we used to dequantize for each linear call? I guess that makes sense since it's essentially weight only quant.

return F.linear(input, weight.to(input.dtype))

@staticmethod
# pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
Expand Down
Loading