diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 837f6bf1c..96f8ce4e6 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -322,10 +322,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - non_sign_bits = total_bits - (1 if signed else 0) + non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 - if not signed: - additional_items = 2 * additional_items for i in range(max_exponent_bits): fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) boundaries = torch.linspace(0.1, 1, fraction_items) @@ -334,16 +332,18 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): if signed: data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if additional_items > 0: + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() data.append(0) data.append(1.0) + assert len(data) == 2**total_bits + gap = 256 - len(data) for i in range(gap): data.append(0) diff --git a/tests/test_functional.py b/tests/test_functional.py index d7212b047..f825c14df 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -129,6 +129,7 @@ def test_quantile_quantization(): assert diff < 0.001 + def test_dynamic_quantization(): diffs = [] reldiffs = [] @@ -141,8 +142,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - # print(sum(diffs)/len(diffs)) - # print(sum(reldiffs)/len(reldiffs)) + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") @@ -157,7 +158,8 @@ def test_dynamic_quantization(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) @pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) -def test_dynamic_blockwise_quantization(dtype, nested, blocksize): +@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False']) +def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): #print('') diffs = [] reldiffs = [] @@ -178,9 +180,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize): assert A2.dtype == dtype diffs = [] + code = F.create_dynamic_map(signed=signed) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) - C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) A2 = F.dequantize_blockwise(C, S) diff = torch.abs(A1 - A2).float() reldiff = diff / torch.abs(A1.float() + 1e-8) @@ -189,11 +192,15 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize): #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) abserr = sum(diffs)/len(diffs) relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.0035 - assert relerr < 0.015 + if signed: + assert abserr < 0.0035 + assert relerr < 0.015 + else: + assert abserr < 0.00175 + assert relerr < 0.012 assert A2.dtype == dtype - #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))