-
Notifications
You must be signed in to change notification settings - Fork 631
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
98cbc4b
commit 2f2063b
Showing
2 changed files
with
74 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) | |
|
||
|
||
|
||
def create_dynamic_map(signed=True, n=7): | ||
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): | ||
""" | ||
Creates the dynamic quantization map. | ||
|
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7): | |
# these are additional items that come from the case | ||
# where all the exponent bits are zero and no | ||
# indicator bit is present | ||
additional_items = 2 ** (7 - n) - 1 | ||
non_sign_bits = total_bits - (1 if signed else 0) | ||
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 | ||
if not signed: | ||
additional_items = 2 * additional_items | ||
for i in range(n): | ||
fraction_items = ( | ||
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 | ||
) | ||
for i in range(max_exponent_bits): | ||
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) | ||
boundaries = torch.linspace(0.1, 1, fraction_items) | ||
means = (boundaries[:-1] + boundaries[1:]) / 2.0 | ||
data += ((10 ** (-(n - 1) + i)) * means).tolist() | ||
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() | ||
if signed: | ||
data += (-(10 ** (-(n - 1) + i)) * means).tolist() | ||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() | ||
|
||
if additional_items > 0: | ||
boundaries = torch.linspace(0.1, 1, additional_items + 1) | ||
means = (boundaries[:-1] + boundaries[1:]) / 2.0 | ||
data += ((10 ** (-(n - 1) + i)) * means).tolist() | ||
if signed: | ||
data += (-(10 ** (-(n - 1) + i)) * means).tolist() | ||
if additional_items > 0: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
TimDettmers
Author
Collaborator
|
||
boundaries = torch.linspace(0.1, 1, additional_items + 1) | ||
means = (boundaries[:-1] + boundaries[1:]) / 2.0 | ||
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() | ||
if signed: | ||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() | ||
|
||
data.append(0) | ||
data.append(1.0) | ||
|
||
gap = 256 - len(data) | ||
for i in range(gap): | ||
data.append(0) | ||
|
||
data.sort() | ||
return Tensor(data) | ||
|
||
|
@@ -371,9 +375,7 @@ def nvidia_transform( | |
return out, new_state | ||
|
||
|
||
def estimate_quantiles( | ||
A: Tensor, out: Tensor = None, offset: float = 1 / 512 | ||
) -> Tensor: | ||
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: | ||
''' | ||
Estimates equidistant quantiles (256 by default) on the input tensor eCDF. | ||
|
@@ -393,25 +395,36 @@ def estimate_quantiles( | |
out : torch.Tensor | ||
Tensor with the 256 estimated quantiles. | ||
offset : float | ||
The offset for the first and last quantile from 0 and 1. Default: 1/512 | ||
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) | ||
num_quantiles : int | ||
The number of equally spaced quantiles. | ||
Returns | ||
------- | ||
torch.Tensor: | ||
The 256 quantiles in float32 datatype. | ||
''' | ||
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') | ||
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") | ||
if num_quantiles < 256 and offset == 1/(512): | ||
# override default arguments | ||
offset = 1/(2*num_quantiles) | ||
|
||
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) | ||
is_on_gpu([A, out]) | ||
device = pre_call(A.device) | ||
if A.dtype == torch.float32: | ||
lib.cestimate_quantiles_fp32( | ||
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) | ||
) | ||
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) | ||
elif A.dtype == torch.float16: | ||
lib.cestimate_quantiles_fp16( | ||
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) | ||
) | ||
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) | ||
else: | ||
raise NotImplementedError(f"Not supported data type {A.dtype}") | ||
post_call(device) | ||
|
||
if num_quantiles < 256: | ||
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) | ||
out = out[idx] | ||
|
||
return out | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The indentation here looks clearly wrong: it repeats the additional_items block for every value of the exponent bits, whereas it should apply only to the last value.