-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rfcs: support int4/int8 compression for K/V in fuesd SDPA
- Loading branch information
Showing
3 changed files
with
326 additions
and
0 deletions.
There are no files selected for viewing
326 changes: 326 additions & 0 deletions
326
rfcs/20240808-graph-api-int-compression-for-sdpa/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,326 @@ | ||
# Graph API: Support int4/int8 compression for K/V in fused SDPA | ||
|
||
## Introduction | ||
|
||
Int4 and int8 compression for Key and Value are exploited in fused SDPA to | ||
reduce the memory footprint of generative inference of LLM, especially when KV | ||
cache mechanism is adopted. In more details, Key and Value tensors will be | ||
stored in lower integral data types such as int4 and int8 for less memory | ||
consumption, and will be de-quantized to wider floating point data types such as | ||
f16 and bf16 for computation. According to the investigation and user request, | ||
it's also possible that Key tensor will be stored in int4 while Value tensor in | ||
int8 as it's more sensitive to precision. | ||
|
||
Currently, oneDNN Graph API already supports int8 data types and | ||
per-tensor/channel int8 quantization, but does not support int4 data types. In | ||
addition, due to the limited range of representation of int4 data types, grouped | ||
scales and zero points are required for int4 quantization and dequantization, | ||
which are not supported in oneDNN Graph either. With that, this document will | ||
concentrate more on supporting int4 data types and its quantization in oneDNN | ||
Graph in the following parts. | ||
|
||
To be more specific, in this RFC we are going to discuss about supporting the | ||
following features in oneDNN Graph API: | ||
|
||
1. Support u4/s4 data types in Graph API. | ||
2. Support int4 data types quantization and dequantization. | ||
3. Extend the existing SDPA pattern with compressed K/V inputs. | ||
|
||
## Proposals | ||
|
||
### Proposal 1: Support u4/s4 data types | ||
|
||
The proposal is to follow oneDNN primitive APIs design with adding int4 data | ||
types enumerations in `dnnl_data_type_t` (defined in dnnl_common_types.h). The | ||
new enumerations will work for both primitive and graph C APIs. | ||
|
||
```cpp | ||
/// Data type specification | ||
typedef enum { | ||
/// Undefined data type, used for empty memory descriptors. | ||
dnnl_data_type_undef = 0, | ||
/// 16-bit/half-precision floating point. | ||
dnnl_f16 = 1, | ||
/// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. | ||
dnnl_bf16 = 2, | ||
/// 32-bit/single-precision floating point. | ||
dnnl_f32 = 3, | ||
/// 32-bit signed integer. | ||
dnnl_s32 = 4, | ||
/// 8-bit signed integer. | ||
dnnl_s8 = 5, | ||
/// 8-bit unsigned integer. | ||
dnnl_u8 = 6, | ||
/// 64-bit/double-precision floating point. | ||
dnnl_f64 = 7, | ||
/// Boolean data type. Size is C++ implementation defined. | ||
dnnl_boolean = 8, | ||
/// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) | ||
/// with a 5-bit exponent and a 2-bit mantissa. | ||
dnnl_f8_e5m2 = 9, | ||
/// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) | ||
/// with a 4-bit exponent and a 3-bit mantissa. | ||
dnnl_f8_e4m3 = 10, | ||
/// 4-bit signed integer. | ||
dnnl_s4 = 11, | ||
/// 4-bit unsigned integer. | ||
dnnl_u4 = 12, | ||
|
||
/// Parameter to allow internal only data_types without undefined behavior. | ||
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2. | ||
dnnl_data_type_max = 0x7fff, | ||
} dnnl_data_type_t; | ||
``` | ||
|
||
In Graph C++ API, the enum class `logical_tensor::data_type` (defined in | ||
dnnl_graph.hpp) will be extended to support the int4 data types. | ||
|
||
```c++ | ||
class logical_tensor { | ||
// ... | ||
/// Data Types | ||
enum class data_type { | ||
undef = dnnl_data_type_undef, | ||
/// 16-bit/half-precision floating point. | ||
f16 = dnnl_f16, | ||
/// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. | ||
bf16 = dnnl_bf16, | ||
/// 32-bit/single-precision floating point. | ||
f32 = dnnl_f32, | ||
/// 32-bit signed integer. | ||
s32 = dnnl_s32, | ||
/// 8-bit signed integer. | ||
s8 = dnnl_s8, | ||
/// 8-bit unsigned integer. | ||
u8 = dnnl_u8, | ||
/// Boolean data type. Size is C++ implementation defined. | ||
boolean = dnnl_boolean, | ||
/// 8-bit floating point data type with E5M2. Added by this RFC. | ||
f8_e5m2 = dnnl_f8_e5m2, | ||
/// 8-bit floating point data type with E4M3. Added by this RFC. | ||
f8_e4m3 = dnnl_f8_e4m3, | ||
/// 4-bit signed integer. | ||
s4 = dnnl_s4, | ||
/// 4-bit unsigned integer. | ||
u4 = dnnl_u4, | ||
}; | ||
// ... | ||
}; | ||
``` | ||
The new data types can be used to create logical tensors and tensors in oneDNN | ||
Graph API. | ||
```c++ | ||
using namespace dnnl::graph; | ||
// create a s4 logical tensor with unknown ndim and dims. | ||
auto lt1 = logical_tensor(id, logical_tensor::data_type::s4); | ||
// create a u4 logical tensor with dims = (16, 8). | ||
const logical_tensor::dims shape = {16, 8} | ||
auto lt2 = logical_tensor(id, logical_tensor::data_type::u4, shape); | ||
``` | ||
|
||
### Proposal 2: Extend Quantize and Dequantize Operations | ||
|
||
The recommended proposal is that oneDNN Graph API will support int4 data type | ||
through graph fusion following the way that int8 is supported, instead of | ||
directly adding int4 datatype in computation operations. | ||
|
||
Based on the request, int4 quantization happens in runtime (scales and zps are | ||
stored as tensors on device). Graph API will support int4 quantization through | ||
operations `DynamicDequantize` and `DynamicQuantize`. In the future, it's | ||
possible that we will also extend the static version of `Quantize` and | ||
`Dequantize` to support int4 quantization. But for now, we wil focus on | ||
`DynamicQuantize` and `DynamicDequantize` due to the request urgency and the | ||
feature use scenario. | ||
|
||
According to the request from frameworks, for int4 quantization, grouped scales | ||
are required for each hidden dimension to maintain the model accuracy. The | ||
requirement for grouped zero points is not promised. See [int4 | ||
Quantize](https://docs.nvidia.com/deeplearning/tensorrt/operators/docs/Quantize.html?highlight=int4) | ||
and | ||
[Dequantize](https://docs.nvidia.com/deeplearning/tensorrt/operators/docs/Dequantize.html) | ||
in TensorRT and | ||
[QuantizeLinear](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html) in | ||
ONNX. | ||
|
||
Currently, `DynamicDequantize` and `DynamicQuantize` in Graph API accept a | ||
required f32 1D tensor as the scale factor, and an optional 1D tensor as the | ||
zero points( can be either s8/u8/f32 ). What's more, two optional attributes are | ||
provided: `qtype` is used to specify which quantization type is used, | ||
`per_tensor` or `per_channel`. And `axis` specifies on which dimension | ||
`per_channel` quantization will be applied. To support int4 data types, | ||
`DynamicQuantize` and `DynamicDequantize` operations in oneDNN Graph opset will | ||
be extended: | ||
|
||
1. `DynamicQuantize` will be extended to support u4/s4 output, and scales and | ||
zero points with groups. The library may raise an error if groups are not | ||
provided for a `DynamicQuantize` op with u4/s4 output. | ||
2. `DynamicDequantize` will be extended to support u4/s4 input, and scales and | ||
zero points with groups. The library may raise an error if groups are not | ||
provided for a `DynamicDequantize` op with u4/s4 input. | ||
|
||
Besides the input or output data types, new attributes will be added for | ||
DynamicQuantize and DynamicDequantize: | ||
|
||
1. Add `per_group` to the supported values of `qtype` attribute, and the default value | ||
will be unchanged. | ||
2. Add a new optional attribute `group_size` to support `per_group` quantization | ||
type. The value of `group_size` attribute is a single `s64` data, which | ||
indicates the number of elements that share the same scaling factor and zero | ||
points in each quantization group. The attribute is only required when | ||
`per_group` quantization type is specified for `qtype`. If `per_group` | ||
quantization is not specified and `group_size` attribute is given, it will | ||
be ignored. | ||
3. The `scale` and `zp` input will be extended to a `n`-dimensional tensor | ||
(`n` = dims(input)) for `per_group` quantization. The shape requirements for | ||
`scale` and `zp` are as follows: | ||
1. For `per_tensor` and `per_channel` quantization, the `scale` and `zp` | ||
inputs should be `1d` tensors. | ||
2. For `per_group` quantization: | ||
- One of the last two dimensions should match the input, and the other | ||
should equal to `src_dim / group_size`. | ||
- On other dimensions, the dimension should be either 1 or match the | ||
input. | ||
|
||
For example, if the shape of Key tensor is `4x8x4`, and the group size is `4`, | ||
which means each scaling factor will be adopted for 4 times. One possible scale | ||
would be: | ||
|
||
![alt text](img/grouped_scale.png) | ||
|
||
The usage of the new attribute will be like: | ||
|
||
```cpp | ||
using namespace dnnl::graph; | ||
|
||
const dim K = 10, N = 20; | ||
const size_t nG = 2, G = N / nG; | ||
|
||
dims src_dims = {K, N}; | ||
dims scale_dims = {K, N/G}; | ||
|
||
// In real case, users do not have to provide the concrete shapes of logical | ||
// tensors until compilation stage. | ||
logical_tensor src_desc {0, data_type::u4, src_dims, layout_type::strided}; | ||
logical_tensor scale_desc { | ||
1, data_type::f32, scale_dims, layout_type::strided}; | ||
logical_tensor dst_desc {2, data_type::f32, src_dims, layout_type::strided}; | ||
|
||
op deq = op(0, op::kind::DynamicDequantize, {src_desc, scale_desc}, | ||
{dst_desc}, "deq"); | ||
// Specify the quantization type as per_group quantization. | ||
deq.set_attr<std::string>(op::attr::qtype, "per_group"); | ||
// Group size indicates the the size of each group. | ||
deq.set_attr<int64_t>(op::attr::group_size, G); | ||
// Axis indicates on which dimension groups of scales will be shared. | ||
deq.set_attr<int64_t>(op::attr:axis, 1); | ||
|
||
// Create graph and add the op to the graph | ||
graph g(dnnl::engine::kind::gpu); | ||
g.add_op(deq); | ||
|
||
// ... | ||
``` | ||
#### Alternative options | ||
Similar to the discussion of int8 and fp8 data types, an alternative solution | ||
may be supporting int4 data types directly in computation operations like MatMul | ||
and Convolution. But it will bloat and complicate the opset and op schema of | ||
oneDNN Graph. Hence it's not considered here. | ||
In addition, currrently oneDNN Graph does not support the direct int8 | ||
quantization and dequantization between int8 and f16/bf16 data types. To align | ||
with current integration flow and maintain the consistency of library design, | ||
quantization and dequantization between int4 and f16/bf16 data types will also | ||
be implemented with a combination of quantization ops and `TypeCast`. An | ||
alternative solution may be to extend quantization ops to support data type | ||
conversion directly between int4 and f16/bf16 data types. But it will | ||
unnecessarily complicates the fusion patterns in oneDNN Graph. For example, | ||
an int4 quantization will be expressed in two different ways: | ||
1. With a single `Quantize(f16 -> int4)` | ||
2. With two operations `TypeCast(f16 -> f32) -> Quantize(f32 -> int4)` | ||
It will require extensions for all patterns related to mixed data type | ||
quantization. Hence the option is not considered here. | ||
### Proposal 3: Extend SDPA pattern with compressed K/V | ||
#### Pattern definition | ||
With the reference to int4 GQA pattern in | ||
[PyTorch](https://pytorch.org/blog/int4-decoding/), the proposed fusion pattern | ||
for SDPA with compressed K/V in oneDNN Graph API is defined as: | ||
![alt text](img/sdpa_pattern.png) | ||
As illustrated in the figure above, K cache and V cache are stored on user side | ||
in int4 or int8 data types. `DynamicDequantize` is used to convert int4 data to | ||
f32, then an optional `TypeCast` is used to convert f32 data types to f16 or | ||
bf16 data types when needed. Two `MatMul`s are used to compute the matrix | ||
multiplication of Q & K, and score & V. oneDNN Graph will fuse the pattern into | ||
a single partition with optimized CPU and GPU implementations. | ||
To get rid of the ambiguity of the computation type of the MatMul ops in the | ||
pattern and give users the control of leveraging higher precision data types to | ||
keep the model accuracy, we propose to follow the design of floating point math | ||
mode in primitive API. Once the math mode is specified, the backend will | ||
automatically perform either implicit down-conversion to lower-precision | ||
integral type values or up-conversion from int4/int8 to floating points, such as | ||
f16, based on user input. For compressed SDPA, since the Key tensor will be in | ||
integral types like int4 and `MatMul`s are expected to compute in f16/bf16, | ||
users must explicitly set the math mode. Put it in another way, the API allows | ||
users to control the computation type of the SDPA fusion. What's more, following | ||
the design of floating point math mode also helps to maintain a consistent API | ||
semantics between graph API and primitive API. | ||
Currently, oneDNN Graph API supports setting floating-point math mode during the | ||
construction of graph object, which will affect the whole graph. | ||
As primitive API has supported a second argument for enforcing an integral | ||
primitive to comply with the floating-point math mode, oneDNN Graph API need to | ||
accommodate to the change. The new API will be like: | ||
```cpp | ||
/// @param mode Floating-point math mode. | ||
/// @param apply_to_int Use of floating-point arithmetic for integer primitives. | ||
void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false); | ||
``` | ||
|
||
Users can use the new API like: | ||
|
||
```cpp | ||
using namespace dnnl::graph; | ||
|
||
graph g(kind); | ||
g.set_fpmath_mode( math_mode, /*apply_to_int=*/true); | ||
|
||
op foo(id, kind, "foo"); g.add_op(foo); | ||
op bar(id, kind, "bar"); g.add_op(bar); | ||
g.finalize(); | ||
partitions p = g.get_partitions(); | ||
|
||
// All the partitions from the same graph should share the same math mode | ||
compiled_partition cp0 = p[0].compile(inputs, outputs, engine); | ||
compiled_partition cp1 = p[1].compile(inputs, outputs, engine); | ||
cp0.execute(…); | ||
cp1.execute(…); | ||
``` | ||
User can still set graph-level fpmath mode through graph constructor, and only | ||
one method will be allowed for setting the graph-level fpmath mode. Otherwise, | ||
the library will report an error. | ||
#### Alternative option | ||
Considering the fact that the framework users can create different graphs if | ||
they require different floating point math mode, the current API capability is | ||
enough for covering the user request. An alternative solution is to specify the | ||
floating point math mode on partition. But it may require new APIs and may cause | ||
confusion when the modes on graph and partition are incompatible. Hence it's not | ||
considered here. |
Binary file added
BIN
+90.8 KB
rfcs/20240808-graph-api-int-compression-for-sdpa/img/grouped_scale.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.