add back docstring
gau-nernst committed Jun 17, 2024
1 parent 4ad065f commit bd64efc
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions torchao/prototype/custom_fp_utils.py
@@ -25,6 +25,24 @@ def _n_ones(n: int) -> int:
 
 
 def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+
+    Note: there is no special value (NaN, inf) support in this code. Values
+    outside the representable range of FPx after rounding are clamped to the
+    maximum FPx magnitude (sign is preserved).
+
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+
+    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
     assert x.dtype == torch.float
     assert 1 + ebits + mbits <= 8
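For orientation, a minimal usage sketch of the conversion this docstring describes, using fp6_e3m2 (ebits=3, mbits=2). The import path matches this commit's file, and the expected outputs follow from the docstring's round-to-even and clamping rules rather than from running the code:

import torch
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

# fp6_e3m2: 1 sign bit, 3 exponent bits (bias 3), 2 mantissa bits.
# Largest representable magnitude: (1 + 3/4) * 2**(7 - 3) = 28.0.
x = torch.tensor([0.0, 0.5, 1.0, -2.5, 1000.0], dtype=torch.float)
encoded = _f32_to_fpx_unpacked(x, ebits=3, mbits=2)      # torch.uint8, one code per byte
decoded = _fpx_unpacked_to_f32(encoded, ebits=3, mbits=2)
# Expected: [0.0, 0.5, 1.0, -2.5, 28.0] -- exactly representable values survive
# the round trip, and 1000.0 clamps to the maximum fp6_e3m2 magnitude.
print(encoded, decoded)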

@@ -122,10 +140,17 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     return x.to(torch.uint8)
 
 
+# TODO(future): check if LUT for everything is faster than bit shifting,
+# especially for fp4 (only 2^4=16 unique values).
 def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
-    """
-    TODO(future): check if LUT for everything is faster than bit shifting,
-    especially for fp4.
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
     """
     assert x.dtype == torch.uint8
     assert 1 + ebits + mbits <= 8
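The unpacked bit layout both docstrings describe can also be checked by hand. A standalone sketch, assuming the 6-bit fp6_e3m2 code sits in the low bits of each byte; fp6_e3m2_to_float is a hypothetical helper for illustration, not part of this file:

def fp6_e3m2_to_float(bits: int) -> float:
    # Decode one unpacked fp6_e3m2 value: 1 sign bit, 3 exponent bits (bias 3), 2 mantissa bits.
    sign = -1.0 if (bits >> 5) & 0x1 else 1.0
    exp = (bits >> 2) & 0b111
    man = bits & 0b11
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (man / 4) * 2.0 ** (1 - 3)
    return sign * (1 + man / 4) * 2.0 ** (exp - 3)

print(fp6_e3m2_to_float(0b001100))  # exp=3, mantissa=0 -> +1.0
print(fp6_e3m2_to_float(0b111111))  # all bits set -> -(1 + 3/4) * 2**4 = -28.0
print(fp6_e3m2_to_float(0b000001))  # smallest positive subnormal -> (1/4) * 2**-2 = 0.0625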
