Change NF4Tensor dtype and add support for linear #62
@@ -51,14 +51,50 @@ def noop_detach(func, *args, **kwargs):
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
 def _to_copy(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return func(args[0][0].t()).t()
     return args[0][0].get_original_weight().to(args[1]['dtype'])
+
+
+@implements([torch.ops.aten.to.dtype])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def to_dtype(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
+    return args[0][0].get_original_weight().to(args[0][1])
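With `_to_copy` and `to_dtype` registered, calling `.to(dtype)` on an NF4Tensor dequantizes it back to a plain tensor. A minimal usage sketch; the import path and the QLoRA-style block sizes are assumptions, not part of this diff:

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor  # assumed module path

# 512x512 keeps numel divisible by both block sizes chosen below.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = NF4Tensor.from_tensor(weight, block_size=64, scaler_block_size=256)

# Dispatches to the to_dtype handler above: the quantized payload is
# dequantized and an ordinary bfloat16 tensor comes back.
restored = nf4_weight.to(torch.bfloat16)
assert restored.dtype == torch.bfloat16
```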
+@implements([torch.ops.aten.t.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def t_default(func, *args, **kwargs):
+    a = args[0][0]
+    tensor_meta = SubclassTensorArgs(
+        a.size(),
+        (a.stride(1), a.stride(0)),
+        a.storage_offset(),
+        torch.bits2x4,
+        a.device,
+        a.requires_grad)
+    b = NF4Tensor(
+        tensor_meta,
+        a.block_size,
+        a.n_blocks,
+        a.scaler_block_size,
+        a.quantized_scalers,
+        a.quantization_factor,
+        a.scaler_mean,
+        a.quantized_data,
+        a.nf4)
+    return b
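The transpose handler is metadata-only: it reuses every quantized buffer and just swaps the two strides, which is the same trick `t()` plays on an ordinary strided tensor:

```python
import torch

a = torch.arange(6).reshape(2, 3)    # strides (3, 1)
b = a.t()                            # same storage, strides become (1, 3)
assert b.stride() == (a.stride(1), a.stride(0))
assert b.data_ptr() == a.data_ptr()  # no data was moved
```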
+@implements([torch.ops.aten.mm.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def mm_default(func, *args, **kwargs):
+    return linear_nf4(args[0][0], args[0][1])
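Registering `aten.mm.default` is what lets a whole linear call work end to end: for a 2-D input with no bias, `F.linear(x, w)` lowers to roughly `aten.mm(x, w.t())`, the `t_default` handler above hands back another NF4Tensor, and this handler forwards it to `linear_nf4`. Continuing the sketch above:

```python
import torch.nn.functional as F

x = torch.randn(8, 512, dtype=torch.bfloat16)
out = F.linear(x, nf4_weight)  # weight dequantized on the fly via linear_nf4
assert out.shape == (8, 512)
```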
 @implements(
     [
@@ -160,7 +196,8 @@ def __new__(
             tensor_meta.original_shape,
             tensor_meta.original_strides,
             tensor_meta.storage_offset,
-            dtype=tensor_meta.dtype,
+            # Picked some floating dtype, but we need dtype extensibility
+            dtype=torch.float8_e5m2fnuz,
             device=tensor_meta.device,
             requires_grad=tensor_meta.requires_grad,
         )
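The subclass now advertises `torch.float8_e5m2fnuz` purely as placeholder metadata, since PyTorch has no native 4-bit dtype to report. A stripped-down sketch of the wrapper-subclass pattern in question; everything except the dtype is made up for illustration:

```python
import torch

class TinyWrapper(torch.Tensor):
    # Hypothetical minimal subclass; the real NF4Tensor also carries its
    # quantized buffers and implements many ops.
    @staticmethod
    def __new__(cls, shape):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=torch.float8_e5m2fnuz,  # advertised dtype only, as in the diff
            requires_grad=False,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"sketch implements no ops: {func}")

t = TinyWrapper((4, 4))
print(t.dtype)  # torch.float8_e5m2fnuz, regardless of any backing storage
```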
@@ -198,6 +235,7 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
+        assert inpt_tensor.dim() <= 2
         assert inpt_tensor.dtype == torch.bfloat16
         assert (
             inpt_tensor.numel() % block_size == 0
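The new guard makes the 2-D restriction explicit up front rather than failing later in the blocking math. Continuing the earlier sketch:

```python
w2d = torch.randn(512, 512, dtype=torch.bfloat16)
NF4Tensor.from_tensor(w2d, block_size=64, scaler_block_size=256)  # fine

w3d = torch.randn(4, 512, 512, dtype=torch.bfloat16)
# NF4Tensor.from_tensor(w3d, 64, 256)  # would now trip the dim() <= 2 assert
```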
@@ -428,7 +466,7 @@ def quantize_tensor_nearest(
     # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
     # defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
-        """Dequantize a nf4 value to float16 format"""
+        """Dequantize a nf4 value to bfloat16 format"""
Review comment: Is the NF4 tensor still restricted to bf16 only for the higher precision? Are there any blockers to supporting fp32?

Reply: We should be able to support arbitrary precision for conversion, but of course the fidelity of NF4 is independent of the dtype that was passed during construction.
         # return nf4.index_select(0, value)
         return nf4[value]
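`nf4[value]` is plain advanced indexing used as a gather: `nf4` holds the 16 normal-float codebook entries and `value` holds the unpacked 4-bit indices. An illustration with a made-up codebook:

```python
import torch

codebook = torch.linspace(-1.0, 1.0, 16, dtype=torch.bfloat16)  # stand-in, not the real NF4 table
indices = torch.tensor([0, 15, 7, 8])
print(codebook[indices])  # one codebook entry per index
```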
@@ -546,7 +584,7 @@ class LinearNF4(torch.autograd.Function):
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
-        return F.linear(input, weight.get_original_weight())
Review comment: @drisspg - So we used to dequantize for each linear call? I guess that makes sense since it's essentially weight-only quant.
+        return F.linear(input, weight.to(input.dtype))

     @staticmethod
     # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
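This change also answers the inline question above: yes, the weight is dequantized on every forward call (it is weight-only quantization), but now via the subclass's own `.to(dtype)` path rather than `get_original_weight()`. A usage sketch, assuming `linear_nf4` is importable next to `NF4Tensor`:

```python
from torchao.dtypes.nf4tensor import linear_nf4  # assumed module path

x = torch.randn(8, 512, dtype=torch.bfloat16, requires_grad=True)
out = linear_nf4(x, nf4_weight)  # forward dequantizes the NF4 weight
out.sum().backward()             # backward reuses the saved NF4 weight for grad_input
```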
Review comment: Also check the dtype of inpt_tensor_nf4.to(torch.bfloat16)?