Added better default compute_dtype handling for Linear4bit layers.

TimDettmers committed Jul 22, 2023
1 parent c82f51c commit 412fd0e
Showing 2 changed files with 60 additions and 6 deletions.
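In short: Linear4bit now resolves its compute dtype lazily, on the first forward pass. Inputs in torch.float32 or torch.bfloat16 are safe to compute in, so the layer adopts the input's dtype; a torch.float16 input hitting a float32 compute dtype instead triggers a one-time performance warning. A minimal usage sketch of the new behavior (layer sizes arbitrary; assumes a CUDA build of bitsandbytes at this commit):

import torch
import bitsandbytes as bnb

# compute_dtype left at its default (None): resolved from the first input seen
layer = bnb.nn.Linear4bit(64, 64).cuda()
x = torch.randn(8, 64, dtype=torch.bfloat16, device='cuda')
layer(x)
print(layer.compute_dtype)  # torch.bfloat16, adopted from the input

# fp16 input against an explicit float32 compute dtype: warns once about slow speed
slow = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float32).cuda()
slow(torch.rand(8, 64, dtype=torch.float16, device='cuda'))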
27 changes: 27 additions & 0 deletions bitsandbytes/nn/modules.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload

import warnings
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
@@ -205,6 +206,28 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False

    def set_compute_type(self, x):
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in; we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single-batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
                warnings.filterwarnings('ignore', message='.*inference.')
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
                warnings.filterwarnings('ignore', message='.*inference or training')

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -213,6 +236,10 @@ def forward(self, x: torch.Tensor):

        if getattr(self.weight, 'quant_state', None) is None:
            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)
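Aside: the x.numel() == x.shape[-1] check in set_compute_type above is a cheap single-batch test, since the total element count equals the last dimension exactly when every leading dimension is 1. A quick illustration:

import torch

single = torch.empty(1, 64, dtype=torch.float16)
batched = torch.empty(10, 64, dtype=torch.float16)

print(single.numel() == single.shape[-1])    # True: the 'slow inference' warning path
print(batched.numel() == batched.shape[-1])  # False: the 'slow inference or training' warning path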
39 changes: 33 additions & 6 deletions tests/test_modules.py
@@ -516,7 +516,10 @@ def test_linear_kbit_fp32_bias(module):
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'FP4+fp32', 'FP4+fp16', 'FP4+bf16']
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
@@ -563,10 +566,10 @@ def test_kbit_backprop(module):
        relerrs2.append(relerr2.mean().item())

        if isinstance(module, bnb.nn.Linear8bitLt):
            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
        else:
            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
        ref.zero_grad()
        kbit.zero_grad()
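The switch from torch.testing.assert_close to assert_all_approx_close for the weight gradients relaxes the check to tolerate a bounded number of outlier elements (count=1 here). The helper is defined elsewhere in the test suite; a sketch of its assumed semantics:

import torch

def assert_all_approx_close(a, b, atol, rtol, count):
    # pass as long as at most `count` elements violate the elementwise tolerance
    close = torch.isclose(a, b, rtol=rtol, atol=atol)
    num_off = (~close).sum().item()
    if num_off > count:
        raise AssertionError(f'{num_off} elements differ beyond atol={atol}, rtol={rtol}')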
@@ -608,9 +611,33 @@ def test_fp8linear():
    assert graderr < 0.00002
    assert bgraderr < 0.00002


def test_4bit_warnings():
    dim1 = 64

    with pytest.warns(UserWarning, match=r'inference or training'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning, match=r'inference.'):
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    with pytest.warns(UserWarning) as record:
        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(10, dim1).cuda().half()
        net(inp)

        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
        net = net.cuda()
        inp = torch.rand(1, dim1).cuda().half()
        net(inp)

    assert len(record) == 2
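For users, the warnings disappear once the configured compute dtype matches the fp16 inputs, which is also the fast path. A minimal sketch (sizes arbitrary; when loading models through transformers, the corresponding knob is the 4-bit compute dtype in its quantization config):

import torch
import bitsandbytes as bnb

# matching compute_dtype to the fp16 inputs selects the fast kernel and emits no warning
layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16).cuda()
out = layer(torch.rand(1, 64, dtype=torch.float16, device='cuda'))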