[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops (pytorch#150)
weifengpy authored May 1, 2024
1 parent ae1628b commit a049baf
Showing 2 changed files with 515 additions and 8 deletions.
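
The test file below exercises the tensor-subclass ops this commit adds to NF4Tensor for FSDP2 per-parameter sharding. As a rough sketch (illustrative only, not part of the commit; it uses the same 512x512 sizes and only the calls that appear in the tests), the new ops can be driven like this:

import torch
from torchao.dtypes.nf4tensor import to_nf4

weight = to_nf4(torch.randn(512, 512))        # quantize a 2D weight to NF4
halves = torch.chunk(weight, 2)               # even dim-0 split into two shards
flat = weight.view(-1)                        # flatten to a 1D NF4Tensor
zeros = weight.new_zeros((512, 512))          # all-zero NF4 tensor of the same size
same = torch.as_strided(weight, weight.size(), weight.stride(), weight.storage_offset())
gpu = weight.to("cuda", non_blocking=True)    # device move (requires CUDA)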
189 changes: 186 additions & 3 deletions test/dtypes/test_nf4.py
@@ -1,6 +1,7 @@
import logging
import unittest
from packaging import version
import math

import torch
from torch import nn
@@ -10,11 +11,17 @@
    parametrize,
    run_tests,
)
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
from torchao.dtypes.nf4tensor import (
    linear_nf4,
    NF4Tensor,
    to_nf4,
    _INNER_TENSOR_NAMES_FOR_SHARDING,
)
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao
from typing import Tuple, Union


bnb_available = False
@@ -234,8 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
        a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
        inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
        out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)



    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
    @parametrize("shape", [(16, 16), (32, 16)])
@@ -250,7 +256,184 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
        torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)



class TestFSDPOps(TestCase):
    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        nf4_tensor = to_nf4(torch.randn(input_size))
        chunks = list(torch.chunk(nf4_tensor, num_chunks))
        self.assertEqual(len(chunks), num_chunks)
        if isinstance(input_size, int):
            expected_size0 = input_size // num_chunks
        else:
            expected_size0 = input_size[0] // num_chunks
        for chunk in chunks:
            self.assertEqual(chunk.size(0), expected_size0)

    @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)])
    def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"):
            nf4_tensor = to_nf4(torch.randn(input_size))
            torch.chunk(nf4_tensor, num_chunks)

    @parametrize("input_size", [(512, 512, 512)])
    def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"):
            nf4_tensor = to_nf4(torch.randn(input_size))
            torch.chunk(nf4_tensor, num_chunks)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        nf4_tensor_zeros = nf4_tensor.new_zeros(input_size)
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(nf4_tensor_zeros, attr)
            self.assertEqual(torch.count_nonzero(inner_tensor), 0)
        expected_size = input_size if not isinstance(input_size, int) else (input_size, )
        self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size))

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
        if isinstance(input_size, int):
            new_size = input_size + 1
        elif len(input_size) == 1:
            new_size = (input_size[0] + 1, )
        else:
            new_size = (input_size[0] + 1, input_size[1])
        nf4_tensor = to_nf4(torch.randn(input_size))
        with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"):
            nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        orig_attrs, _ = nf4_tensor.__tensor_flatten__()
        orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs])
        end_idx = input_size if isinstance(input_size, int) else input_size[0]
        sliced_tensor = nf4_tensor[:end_idx]
        self.assertEqual(nf4_tensor.size(), sliced_tensor.size())
        attrs, _ = sliced_tensor.__tensor_flatten__()
        for attr in attrs:
            orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
            sliced_tensor_inner = getattr(sliced_tensor, attr)
            self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage)
            self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr])

    def test_tensor_slice_1d_invalid(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with customized step"):
            nf4_tensor[..., ::2]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"):
            nf4_tensor[:2]

    def test_tensor_slice_2d_invalid(self):
        nf4_tensor = to_nf4(torch.randn((512, 512)))
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"):
            nf4_tensor[:, :511]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"):
            nf4_tensor[:2]

    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        viewed_tensor = nf4_tensor.view(-1)
        self.assertEqual(viewed_tensor.dim(), 1)
        self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(viewed_tensor, attr)
            self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        if len(input_size) == 1:
            with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
                nf4_tensor.view(input_size)
        if len(input_size) == 2:
            with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"):
                nf4_tensor.view(input_size)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset())
        self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size())
        self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride())
        self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset())
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor_orig = getattr(nf4_tensor, attr)
            inner_tensor_strided = getattr(nf4_tensor_strided, attr)
            self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size())
            self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride())
            self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset())


    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        if len(input_size) == 1:
            size = (input_size[0] - 1, )
        else:
            size = (input_size[0] - 1, input_size[1])
        with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"):
            torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset())
        with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"):
            torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1)

        if len(input_size) == 2:
            with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"):
                stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0])
                torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset())

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_pin_memory(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertFalse(nf4_tensor.is_pinned())

        nf4_tensor = nf4_tensor.pin_memory()
        self.assertTrue(nf4_tensor.is_pinned())

        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
        self.assertFalse(nf4_tensor.is_pinned())


    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_to_cuda(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)
        self.assertEqual(nf4_tensor.device.type, "cuda")

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda")
        self.assertEqual(nf4_tensor.device.type, "cuda")

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16)
        self.assertEqual(nf4_tensor.device.type, "cuda")
        self.assertEqual(nf4_tensor.dtype, torch.bfloat16)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_to_cpu(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
        nf4_tensor = nf4_tensor.cpu()
        self.assertEqual(nf4_tensor.device.type, "cpu")
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(nf4_tensor, attr)
            self.assertEqual(inner_tensor.device.type, "cpu")


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)

if __name__ == "__main__":
run_tests()