From c1d9de3be3a4cdf70b367b2cf8d30e63fbe767ec Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 22 Apr 2024 16:40:57 -0700
Subject: [PATCH 1/2] Bug Fix: NF4 .to('cuda')

---
 test/dtypes/test_nf4.py     | 10 ++++++++++
 torchao/dtypes/nf4tensor.py |  2 +-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index e3b25e3c3..7142bff62 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -192,6 +192,16 @@ def test_to_copy(self, dtype: torch.dtype):
         nf4_to_dtype = inpt_tensor_nf4.to(dtype)
         torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_to_copy_device(self):
+        inpt_tensor = torch.rand(128, device='cpu')
+        t = to_nf4(inpt_tensor, 32, 2)
+        assert t.device == torch.device('cpu')
+        z = t.cuda()
+        assert z.device.type == "cuda"  # Because the device could be cuda:0
+        x = z.cpu()
+        assert x.device == torch.device('cpu')
+
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_dtype(self, dtype: torch.dtype):
         inpt_tensor = torch.rand(128, dtype=dtype)
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 886eb6c0a..f09d53821 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -47,7 +47,7 @@ def _to_copy(func, *args, **kwargs):
     if not args[0][0].is_contiguous():
         assert args[0][0].t().is_contiguous()
         return func(args[0][0].t()).t()
-    return args[0][0].get_original_weight().to(args[1]["dtype"])
+    return args[0][0].get_original_weight().to(args[1]["dtype"]).to(args[1]["device"])
 
 
 @implements([torch.ops.aten.to.dtype])

From 0dfcbfd7de6b7880956799978bf45f2314a58e9e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 22 Apr 2024 16:49:18 -0700
Subject: [PATCH 2/2] more test

---
 test/dtypes/test_nf4.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index 7142bff62..55bbe0bcb 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -202,6 +202,10 @@ def test_to_copy_device(self):
         x = z.cpu()
         assert x.device == torch.device('cpu')
 
+        inpt_tensor = torch.rand(128, device='cuda')
+        t = to_nf4(inpt_tensor, 32, 2)
+        assert t.device.type == "cuda"
+
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_dtype(self, dtype: torch.dtype):
         inpt_tensor = torch.rand(128, dtype=dtype)
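
A minimal usage sketch of the behavior these two patches exercise, mirroring the added tests. It assumes to_nf4 is importable from torchao.dtypes.nf4tensor and that a CUDA device is present; it is an illustration only, not part of the patches.

    import torch
    from torchao.dtypes.nf4tensor import to_nf4

    # Quantize a CPU tensor to NF4 (block_size=32, scaler_block_size=2, as in the tests).
    weight = torch.rand(128)
    nf4_weight = to_nf4(weight, 32, 2)
    assert nf4_weight.device == torch.device('cpu')

    # Round-trip through CUDA; before PATCH 1/2, _to_copy dropped the requested device.
    nf4_cuda = nf4_weight.cuda()
    assert nf4_cuda.device.type == "cuda"
    nf4_cpu = nf4_cuda.cpu()
    assert nf4_cpu.device == torch.device('cpu')

    # Dequantize to a dense dtype for comparison, as test_to_copy does.
    dense = nf4_cpu.to(torch.bfloat16)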