Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150

Merged
merged 47 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
0a13e6a
proof of concept for FSDP2 + NF4Tensor
weifengpy Apr 4, 2024
9a56eaa
Merge branch 'main' into main
cpuhrsch Apr 4, 2024
8180540
fsdp extention for tensor subclass
weifengpy Apr 11, 2024
95b03e1
support fp32
weifengpy Apr 15, 2024
3ac9d81
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 16, 2024
38461b3
UNIT TEST FOR STATE DICT
weifengpy Apr 16, 2024
bc7a764
implement to
weifengpy Apr 17, 2024
8b1d037
remove torch.override from torch function
weifengpy Apr 17, 2024
7ff6855
use dtype in compile unit test
weifengpy Apr 17, 2024
d9bcf71
add dtype in all unit test
weifengpy Apr 17, 2024
923bef2
keep original dtype
weifengpy Apr 17, 2024
e15d244
fix linter
weifengpy Apr 17, 2024
d4beb8f
use torch testing @parametrize
weifengpy Apr 17, 2024
f41cb3d
remove unused import
weifengpy Apr 17, 2024
952fbdd
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 17, 2024
950d9fd
sm8 for fp16
weifengpy Apr 17, 2024
d4eae0b
remove sm check for fp16
weifengpy Apr 18, 2024
9444f2c
skip 2.2.2 and below for tracing tensor subclass
weifengpy Apr 18, 2024
b2c3c02
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 18, 2024
9be2de3
include kwargs
weifengpy Apr 19, 2024
2981393
raise unimplemented
weifengpy Apr 19, 2024
3ced998
Merge branch 'main' into main
weifengpy Apr 19, 2024
3f1e19a
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 19, 2024
761416a
fsdp2 ops
weifengpy Apr 19, 2024
c656f1e
better diff layout
weifengpy Apr 19, 2024
c56d7e2
set pg size in metadata
weifengpy Apr 19, 2024
d656b93
remove irrelevant changes
weifengpy Apr 19, 2024
5c4fe2b
add unit test
weifengpy Apr 20, 2024
613bf67
Merge branch 'main' into main
msaroufim Apr 26, 2024
3933bfa
torch.chunk and cpu offloading ops
weifengpy Apr 27, 2024
9e6b4ec
remove strict same metadata check
weifengpy Apr 27, 2024
857b8db
skip tests that needs cuda
weifengpy Apr 27, 2024
8e3de02
use /( in regex match
weifengpy Apr 27, 2024
912998b
fix regex
weifengpy Apr 28, 2024
8926ee1
skip tests if no cuda
weifengpy Apr 28, 2024
6f834ce
skip unit test if no cuda
weifengpy Apr 28, 2024
a8a5aaa
Merge branch 'pytorch:main' into main
weifengpy Apr 28, 2024
699079d
assert cpu device
weifengpy Apr 30, 2024
c8b047c
name args[0] as nf4tensor
weifengpy Apr 30, 2024
925602c
utils for apply to inner tensors and constructor
weifengpy Apr 30, 2024
e36ab6c
use original copy_
weifengpy Apr 30, 2024
a007027
decorator for args check
weifengpy May 1, 2024
c352552
Merge branch 'main' into main
cpuhrsch May 1, 2024
c83fdad
INNER_TENSOR_NAMES_FOR_SHARDING and unify assert in split and new_zeros
weifengpy May 1, 2024
574fecd
Merge branch 'pytorch:main' into main
weifengpy May 1, 2024
f27760b
indicate private constant with _
weifengpy May 1, 2024
b4f51b9
Merge branch 'main' into fsdp2ops
weifengpy May 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io
from collections import OrderedDict
import torchao
from typing import Tuple, Union


bnb_available = False
Expand Down Expand Up @@ -222,7 +223,98 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


class TestFSDPOps(TestCase):
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]):
num_chunks = 2
nf4_tensor = to_nf4(torch.randn(input_size))
chunks = list(torch.chunk(nf4_tensor, num_chunks))
self.assertEqual(len(chunks), num_chunks)
if isinstance(input_size, int):
expected_size0 = input_size // num_chunks
else:
expected_size0 = input_size[0] // num_chunks
for chunk in chunks:
self.assertEqual(chunk.size(0), expected_size0)

@parametrize("input_size", [511 * 512, (511 * 512,), (511, 512), (512, 512, 512)])
def test_torch_chunk_invalid(self, input_size: Union[Tuple[int], int]):
num_chunks = 2
with self.assertRaises(AssertionError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add the message to the assert message, it is not immediately apparent at least to my why these chunks are invalid

Copy link
Contributor Author

@weifengpy weifengpy Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha. it's from NF4Tensor construction and complains about non-dividble block_size . I can add detailed errror msg there

nf4_tensor = to_nf4(torch.randn(input_size))
torch.chunk(nf4_tensor, num_chunks)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
nf4_tensor_zeros = nf4_tensor.new_zeros(input_size)
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
inner_tensor = getattr(nf4_tensor_zeros, attr)
self.assertEqual(torch.count_nonzero(inner_tensor), 0)
expected_size = input_size if not isinstance(input_size, int) else (input_size, )
self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size))

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
if isinstance(input_size, int):
new_size = input_size + 1
elif len(input_size) == 1:
new_size = (input_size[0] + 1, )
else:
new_size = (input_size[0] + 1, input_size[1])
nf4_tensor = to_nf4(torch.randn(input_size))
with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\(NF4Tensor\) with new size"):
nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
end_idx = input_size if isinstance(input_size, int) else input_size[0]
sliced_tensor = nf4_tensor[:end_idx]
self.assertEqual(nf4_tensor.size(), sliced_tensor.size())
attrs, _ = sliced_tensor.__tensor_flatten__()
for attr in attrs:
orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
self.assertEqual(getattr(sliced_tensor, attr).untyped_storage().data_ptr(), orig_storage)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you assert that that the correct metadata is on the sliced tensor attributes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep. will also assert on metadata


def test_tensor_slice_1d_invalid(self):
nf4_tensor = to_nf4(torch.randn(512 * 512))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with step"):
nf4_tensor[..., ::2]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end "):
nf4_tensor[:2]

def test_tensor_slice_2d_invalid(self):
nf4_tensor = to_nf4(torch.randn((512, 512)))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with dim"):
nf4_tensor[:, :511]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end"):
nf4_tensor[:2]

@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
viewed_tensor = nf4_tensor.view(-1)
self.asssertEqual(viewed_tensor.dim(), 1)
self.asssertEqual(viewed_tensor.numel(), math.prod(input_size))
attrs, _ = sliced_tensor.__tensor_flatten__()
for attr in attrs:
orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
inner_tensor = getattr(sliced_tensor, attr)
self.asssertEqual(inner_tensor.dim(), 1)
self.assertEqual(inner_tensor.untyped_storage().data_ptr(), orig_storage)


# def test_tensor_as_strided(self):
# pass


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)

if __name__ == "__main__":
run_tests()
Loading
Loading