Much lint, so wow (pytorch#76)
cpuhrsch authored Mar 22, 2024
1 parent 11e8163 commit 2871d74
Showing 13 changed files with 303 additions and 632 deletions.
2 changes: 2 additions & 0 deletions CODEOWNERS
@@ -0,0 +1,2 @@
msaroufim
cpuhrsch
132 changes: 53 additions & 79 deletions torchao/dtypes/nf4tensor.py
@@ -2,96 +2,87 @@
from typing import Dict, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F
from torch import Tensor


# pyre-fixme[5]: Global expression must be annotated.
aten = torch.ops.aten
# pyre-fixme[5]: Global expression must be annotated.

c10d_functional = torch.ops.c10d_functional

from typing import Any
# pyre-fixme[5]: Global annotation cannot contain `Any`.

NF4_OPS_TABLE: Dict[Any, Any] = {}


# pyre-fixme[3]: Return type must be annotated.
def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
return (
both_nf4 and
a.block_size == b.block_size
both_nf4
and a.block_size == b.block_size
and a.scaler_block_size == b.scaler_block_size
and a.n_blocks == b.n_blocks
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.

def implements(aten_ops):
"""Use this decorator to implement a function for an aten op in __torch_dispatch__"""

# pyre-fixme[53]: Captured variable `aten_ops` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def decorator(func):
for op in aten_ops:
NF4_OPS_TABLE[op] = func
return func

return decorator


@implements([torch.ops.aten.detach.default, torch.ops.aten.detach])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def noop_detach(func, *args, **kwargs):
return args[0][0]


@implements([torch.ops.aten._to_copy.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _to_copy(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return func(args[0][0].t()).t()
return args[0][0].get_original_weight().to(args[1]['dtype'])
return args[0][0].get_original_weight().to(args[1]["dtype"])


@implements([torch.ops.aten.to.dtype])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def to_dtype(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
return args[0][0].get_original_weight().to(args[0][1])


@implements([torch.ops.aten.t.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def t_default(func, *args, **kwargs):
a = args[0][0]
tensor_meta = SubclassTensorArgs(
a.size(),
(a.stride(1), a.stride(0)),
a.storage_offset(),
torch.bits2x4,
a.device,
a.requires_grad)
a.size(),
(a.stride(1), a.stride(0)),
a.storage_offset(),
torch.bits2x4,
a.device,
a.requires_grad,
)
b = NF4Tensor(
tensor_meta,
a.block_size,
a.n_blocks,
a.scaler_block_size,
a.quantized_scalers,
a.quantization_factor,
a.scaler_mean,
a.quantized_data,
a.nf4)
tensor_meta,
a.block_size,
a.n_blocks,
a.scaler_block_size,
a.quantized_scalers,
a.quantization_factor,
a.scaler_mean,
a.quantized_data,
a.nf4,
)
return b


@implements([torch.ops.aten.mm.default])
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def mm_default(func, *args, **kwargs):
return linear_nf4(args[0][0], args[0][1])
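
Note (not part of the diff): `implements` above fills NF4_OPS_TABLE, which __torch_dispatch__ further down consults to route aten ops to these handlers. A minimal, self-contained sketch of that registration/lookup pattern, with hypothetical names in place of the NF4Tensor internals:

import torch

_OPS_TABLE = {}  # maps aten overloads to handler functions

def implements(aten_ops):
    """Register `func` as the handler for every op in `aten_ops`."""
    def decorator(func):
        for op in aten_ops:
            _OPS_TABLE[op] = func
        return func
    return decorator

@implements([torch.ops.aten.detach.default])
def handle_detach(func, args, kwargs):
    # Treat detach as a no-op on the first positional argument.
    return args[0]

class SketchTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Look the incoming aten op up in the table filled by @implements.
        if func in _OPS_TABLE:
            return _OPS_TABLE[func](func, args, kwargs or {})
        raise NotImplementedError(f"no handler registered for {func}")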

@@ -101,14 +92,12 @@ def mm_default(func, *args, **kwargs):
aten.copy_.default,
]
)
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def copy_(func, *args, **kwargs):
original: NF4Tensor = args[0][0]
copy_in: torch.Tensor = args[0][1]

# Base Case
# pyre-fixme[6]: For 2nd argument expected `NF4Tensor` but got `Tensor`.

if same_metadata(original, copy_in):
original_tensors = original.__tensor_flatten__()[0]
for tensor_name in original_tensors:
@@ -117,7 +106,9 @@ def copy_(func, *args, **kwargs):

# Convert Non NF4Tensor into NF4 for copy in
if not isinstance(copy_in, NF4Tensor):
copy_in_nf4 = NF4Tensor.from_tensor(copy_in, original.block_size, original.scaler_block_size)
copy_in_nf4 = NF4Tensor.from_tensor(
copy_in, original.block_size, original.scaler_block_size
)
return original.copy_(copy_in_nf4)

# Other Tensor is not a NF4Tensor
@@ -127,10 +118,11 @@ def copy_(func, *args, **kwargs):
)
return original.copy_(same_meta_nf4)


@dataclass
class SubclassTensorArgs:
original_shape: torch.Size
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.

original_strides: Tuple
storage_offset: int
dtype: torch.dtype
@@ -161,7 +153,6 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
class NF4Tensor(torch.Tensor):
"""NF4Tensor class for converting a weight to the QLoRA NF4 format"""

# pyre-fixme[3]: Return type must be annotated.
def __new__(
cls,
# Args related for base tensor construction
@@ -190,7 +181,6 @@ def __new__(
"""

# pyre-fixme[16]: `Tensor` has no attribute `_make_wrapper_subclass`.
nf4tensor = torch.Tensor._make_wrapper_subclass(
cls,
tensor_meta.original_shape,
@@ -203,7 +193,6 @@
)
return nf4tensor

# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
tensor_meta: SubclassTensorArgs,
@@ -228,7 +217,6 @@ def __init__(

@classmethod
@torch.no_grad()
# pyre-fixme[3]: Return type must be annotated.
def from_tensor(
cls,
inpt_tensor: torch.Tensor,
@@ -342,7 +330,6 @@ def double_quantize_scalers(
n_scaler_blocks, scaler_block_size
)

# pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`.
quantization_factor = 256 / (2 * scaler_absmax)
# Length equal to weight numel // block_size
quantized_scaler_blocks = scaler_blocks * quantization_factor
@@ -352,7 +339,7 @@
# This is needed to make sure that quantization_factor remains a repeated view of n_scaler_blocks
# For some reason the 127/scaler_absmax realizes n_scaler entries when only n_scaler_blocks are needed
# The following will grab the first entry for the n_scaler_blocks which is the same across the scaler_block_size
# pyre-fixme[16]: `float` has no attribute `__getitem__`.

quantization_factor = quantization_factor[:, 0]

return (
@@ -389,7 +376,6 @@ def dequantize_scalers(

@staticmethod
def convert_to_norm_float_weight(
# pyre-fixme[11]: Annotation `tensor` is not defined as a type.
inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
) -> torch.Tensor:
"""Convert a tensor to the normalized float weight format"""
@@ -450,7 +436,6 @@ def get_original_weight(self) -> torch.Tensor:

@staticmethod
def quantize_tensor_nearest(
# pyre-fixme[11]: Annotation `float16` is not defined as a type.
value: torch.float16, nf4: torch.Tensor
) -> torch.Tensor:
"""Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
@@ -461,9 +446,9 @@ def quantize_tensor_nearest(
return closest_nf4

@staticmethod
# pyre-fixme[14]: `dequantize` overrides method defined in `TensorBase`

# inconsistently.
# pyre-fixme[40]: Static method `dequantize` cannot override a non-static method

# defined in `torch._C.TensorBase`.
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
"""Dequantize a nf4 value to bfloat16 format"""
@@ -475,7 +460,7 @@ def unpack(
) -> Tuple[
int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
]:
# pyre-fixme[7]: Expected `Tuple[int, int, Tensor, Tensor, Tensor, Tensor,

# Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
return (
self.block_size,
@@ -487,15 +472,12 @@ def unpack(
self.quantized_data,
)

# pyre-fixme[14]: `__repr__` overrides method defined in `Tensor` inconsistently.
# pyre-fixme[3]: Return type must be annotated.
def __repr__(self):
return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"

def __str__(self):
return f"NF4Tensor({self.shape}, {self.block_size})"

# pyre-fixme[3]: Return type must be annotated.
def __tensor_flatten__(self):
tensor_meta = SubclassTensorArgs(
self.shape,
@@ -520,10 +502,9 @@ def __tensor_flatten__(self):
], ctx

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use

# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
# pyre-fixme[2]: Parameter must be annotated.

def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 5, "Expected 5 inner tensors"
return NF4Tensor(
@@ -538,28 +519,25 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
inner_tensors["nf4"],
)


# pyre-fixme[3]: Return type must be annotated.
def __str__(self):
return self.to(torch.float32).__str__()

@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args, kwargs=None):
"""TODO we are not supporting torch dispatch at the moment
instead we have created a Autograd.Function to handle the linear
"""
# All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs
# And don't support mixed tensor subclasses. This will trigger the handler for
# the next type in the dispatch list
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.

def allowed_subclasses(type):
return (
issubclass(cls, type)
or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
or issubclass(
torch._subclasses.functional_tensor.FunctionalTensor, type
)
)

if not all(allowed_subclasses(t) for t in types):
@@ -572,25 +550,24 @@ def allowed_subclasses(type):
)

# Do not force the Float8Tensor type on the returned tensor
# pyre-fixme[4]: Attribute must be annotated.

__torch_function__ = torch._C._disabled_torch_function_impl


class LinearNF4(torch.autograd.Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `_SingleLevelFunction`

# inconsistently.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.

def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
return F.linear(input, weight.to(input.dtype))

@staticmethod
# pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`

# inconsistently.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.

def backward(ctx, grad_output):
"""The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()"""
weight: NF4Tensor = ctx.nf4_weight
@@ -606,10 +583,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
"""
return LinearNF4.apply(input, weight)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def to_nf4(tensor,
block_size: int = 64,
scaler_block_size: int = 256):

def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
tensor1 = tensor.to(torch.bfloat16)
return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
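
Note (not part of the commit): a minimal usage sketch of the public helpers touched here, `to_nf4` and `linear_nf4`, assuming they are importable from `torchao.dtypes.nf4tensor` as defined in this file:

import torch
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

# Quantize a weight to NF4 using the default block sizes from this file.
weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

# linear_nf4 dequantizes the weight inside LinearNF4.forward and calls F.linear.
x = torch.randn(4, 512, dtype=torch.bfloat16)
out = linear_nf4(x, nf4_weight)
print(out.shape)  # torch.Size([4, 512])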