Support FP8 KV Cache (#652)
ajtejankar authored Oct 29, 2024
1 parent 373c3e6 commit 2ff1c71
Showing 16 changed files with 243 additions and 39 deletions.
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -32,6 +32,7 @@ enum Quantization {
Hqq_3bit,
Hqq_2bit,
Fp8,
Fp8_KV,
}

impl std::fmt::Display for Quantization {
@@ -68,6 +69,9 @@ impl std::fmt::Display for Quantization {
Quantization::Fp8 => {
write!(f, "fp8")
}
Quantization::Fp8_KV => {
write!(f, "fp8-kv")
}
}
}
}
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -23,6 +23,7 @@ class Quantization(str, Enum):
hqq_3bit = "hqq-3bit"
hqq_2bit = "hqq-2bit"
fp8 = "fp8"
fp8_kv = "fp8-kv"


class Dtype(str, Enum):
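Taken together with the launcher change above, `fp8-kv` is now a valid quantization string end to end. A small illustration of how the new member round-trips through the enum (hypothetical snippet, not part of the diff):

```python
from enum import Enum

# Condensed mirror of the Quantization enum above (only the FP8 members).
class Quantization(str, Enum):
    fp8 = "fp8"
    fp8_kv = "fp8-kv"

# A CLI layer parsing the string form resolves to the new member.
assert Quantization("fp8-kv") is Quantization.fp8_kv
assert Quantization.fp8_kv.value == "fp8-kv"
```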
6 changes: 4 additions & 2 deletions server/lorax_server/layers/linear.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F

from lorax_server.utils.import_utils import SYSTEM
from lorax_server.utils.torch_utils import is_fp8

if SYSTEM == "rocm":
try:
@@ -95,9 +96,10 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None,
if fan_in_fan_out:
weight = weight.T.contiguous()

if quantize is None or (quantize == "fp8" and weight_scale is None):
if quantize is None:
linear = FastLinear(weight, bias)
elif quantize == "fp8":

elif is_fp8(quantize):
from lorax_server.layers.fp8 import Fp8Linear

linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)
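The `is_fp8` predicate imported above, along with `is_fp8_kv` and `is_quantized` used in the modeling files below, comes from `lorax_server/utils/torch_utils.py`, which is not part of this excerpt. A rough sketch of what these helpers plausibly check, inferred only from their call sites in this diff (an assumption, not the actual file):

```python
# Hypothetical reconstruction of the lorax_server/utils/torch_utils.py helpers,
# inferred from how this diff uses them. The exact membership sets are assumptions.

def is_fp8_kv(quantize) -> bool:
    # Only the new mode also stores the KV cache in FP8.
    return quantize == "fp8-kv"

def is_fp8(quantize) -> bool:
    # Both FP8 modes route weights through Fp8Linear in get_linear above.
    return quantize in ("fp8", "fp8-kv")

def is_quantized(quantize) -> bool:
    # Modes whose checkpoints should not be cast back to the model dtype
    # in _load_gqa (replaces the hard-coded lists in the modeling files).
    return quantize in ("gptq", "awq", "marlin", "fp8", "fp8-kv")
```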
2 changes: 1 addition & 1 deletion server/lorax_server/layers/tensor_parallel.py
@@ -37,7 +37,7 @@ def load(config, prefix: str, weights):
should_gather = False

# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "fp8"]:
if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8-kv"]:
quantize = None
else:
quantize = config.quantize
server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -51,6 +51,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class Gemma2Config(PretrainedConfig):
@@ -170,7 +171,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq", "marlin"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.head_dim
@@ -212,6 +213,14 @@ def __init__(self, layer_id: int, prefix: str, config, weights, causal: bool, is
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -257,7 +266,16 @@ def forward(

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -273,6 +291,9 @@ def forward(
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -286,6 +307,8 @@ def forward(
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
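Conceptually, passing `k_scale`, `v_scale`, and `fp8_kv=True` to `reshape_and_cache` asks the kernel to write K and V into the paged cache as 8-bit floats instead of the model dtype. A rough pure-PyTorch sketch of that write-side quantization (an illustration of the idea, not the fused kernel):

```python
import torch

def quantize_kv_for_cache(k: torch.Tensor, v: torch.Tensor,
                          k_scale: float, v_scale: float):
    # Divide by the per-layer scale so values fit the FP8 E4M3 range,
    # then cast; each cached element now occupies one byte.
    k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)
    v_fp8 = (v / v_scale).to(torch.float8_e4m3fn)
    return k_fp8, v_fp8
```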
27 changes: 25 additions & 2 deletions server/lorax_server/models/custom_modeling/flash_gemma_modeling.py
@@ -44,6 +44,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class GemmaConfig(PretrainedConfig):
@@ -153,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.head_dim
@@ -197,6 +198,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -264,7 +273,16 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -278,6 +296,9 @@ def forward(
cu_seqlen_prefill,
max_s,
self.softmax_scale,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -291,6 +312,8 @@ def forward(
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
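As in the Gemma2 layer, `k_scale` and `v_scale` are read per attention layer from the checkpoint at `{prefix}.k_scale` / `{prefix}.v_scale`, so an FP8-KV checkpoint must ship these scalar tensors alongside its weights. A hypothetical inspection snippet (the file name and tensor prefix are illustrative):

```python
from safetensors import safe_open

# Illustrative only: the actual prefix depends on the architecture and layer.
with safe_open("model.safetensors", framework="pt") as f:
    k_scale = f.get_tensor("model.layers.0.self_attn.k_scale").item()
    v_scale = f.get_tensor("model.layers.0.self_attn.v_scale").item()
    print(k_scale, v_scale)
```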
27 changes: 25 additions & 2 deletions server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -52,6 +52,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class LlamaConfig(PretrainedConfig):
@@ -200,7 +201,7 @@ def _load_gqa(config, prefix: str, weights):
if isinstance(weight, tuple):
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.hidden_size // config.num_attention_heads
@@ -252,6 +253,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -319,7 +328,16 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -333,6 +351,9 @@ def forward(
cu_seqlen_prefill,
max_s,
self.softmax_scale,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -347,6 +368,8 @@ def forward(
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
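How those per-layer scales are produced is outside this diff; one common recipe (an assumption here, not something the commit specifies) is to calibrate so that the largest observed |K| or |V| activation maps onto the FP8 E4M3 maximum:

```python
import torch

def calibrate_kv_scale(observed: torch.Tensor) -> float:
    # torch.finfo(torch.float8_e4m3fn).max is 448.0, so the scale maps the
    # largest calibration activation onto the top of the representable range.
    return (observed.abs().max() / torch.finfo(torch.float8_e4m3fn).max).item()
```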
server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -53,6 +53,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized

if not HAS_FLASH_ATTN_V2_CUDA:
raise ImportError("Mistral model requires flash attn v2")
@@ -205,7 +206,7 @@ def _load_gqa(config, prefix: str, weights, head_size):
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

num_heads = config.num_attention_heads // weights.process_group.size()
@@ -260,6 +261,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size)

@@ -333,7 +342,16 @@ def forward(
else:
kv_to_cache = kv

paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv_to_cache[:, 0],
kv_to_cache[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -348,6 +366,9 @@ def forward(
max_s,
self.softmax_scale,
window_size_left=self.max_past,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -361,6 +382,8 @@ def forward(
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
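On the read side, the same scales are threaded into both the prefill attention call and the paged decode call so the kernels can rescale FP8 cache blocks back to the compute dtype. A rough sketch of that inverse transform (illustrative; the real kernels fuse it into the attention math):

```python
import torch

def dequantize_kv_from_cache(k_fp8: torch.Tensor, v_fp8: torch.Tensor,
                             k_scale: float, v_scale: float,
                             dtype: torch.dtype = torch.float16):
    # Inverse of the write-side step: upcast and multiply the scale back in.
    return k_fp8.to(dtype) * k_scale, v_fp8.to(dtype) * v_scale
```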
(The remaining changed files in this commit are not shown in this excerpt.)
