
Commit 5dfb5e5: fix fp8 error handling

ajtejankar committed Oct 28, 2024
1 parent 813b64b commit 5dfb5e5
Showing 2 changed files with 6 additions and 3 deletions.
server/lorax_server/models/flash_causal_lm.py (5 changes: 4 additions & 1 deletion)

```diff
@@ -40,7 +40,7 @@
     warmup_mode,
 )
 from lorax_server.utils.tokenizer import TokenizerManager
-from lorax_server.utils.torch_utils import is_fp8_kv
+from lorax_server.utils.torch_utils import is_fp8_kv, is_fp8_supported, is_fp8
 from lorax_server.utils.weights import Weights

 ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1"))
@@ -957,6 +957,9 @@ def __init__(
         config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
         config.quantize = quantize

+        if is_fp8(config.quantize) and not is_fp8_supported():
+            raise ValueError('FP8 quantization is only supported on hardware that supports FP8')
+
         if is_fp8_kv(config.quantize):
             self.kv_dtype = torch.float8_e4m3fn
             logger.info('Enabling FP8 KV Cache')
```
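Why the guard moved: before this commit, `is_fp8` and `is_fp8_kv` folded the hardware check into the string match (see the second file below), so requesting `fp8` quantization on a GPU without FP8 support simply returned `False` and the server silently proceeded without FP8 instead of reporting the misconfiguration. A minimal sketch of the behavioral difference, with the `_old`/`_new` suffixes and the `is_fp8_supported` stub added here purely for illustration:

```python
# Stub of the repo's hardware probe so this sketch runs standalone;
# pretend we are on a GPU without FP8 kernels.
def is_fp8_supported():
    return False

# Pre-commit predicate: hardware support was folded into the string check.
def is_fp8_old(quantize):
    return is_fp8_supported() and quantize and quantize.startswith('fp8')

# Post-commit predicate: a pure string check with no hardware dependency.
def is_fp8_new(quantize):
    return quantize and quantize.startswith('fp8')

print(is_fp8_old('fp8'))  # False: the FP8 request was silently dropped
print(is_fp8_new('fp8'))  # True: the request is recognized...

# ...and __init__ now rejects it loudly on unsupported hardware:
if is_fp8_new('fp8') and not is_fp8_supported():
    raise ValueError('FP8 quantization is only supported on hardware that supports FP8')
```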
server/lorax_server/utils/torch_utils.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -21,8 +21,8 @@ def is_fp8_supported():


 def is_fp8_kv(quantize):
-    return is_fp8_supported() and quantize and quantize == 'fp8_kv'
+    return quantize and quantize == 'fp8_kv'


 def is_fp8(quantize):
-    return is_fp8_supported() and quantize and quantize.startswith('fp8')
+    return quantize and quantize.startswith('fp8')
```
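The hunk header shows these predicates sit just below `is_fp8_supported()`, whose body is outside the diff. For context, a plausible sketch of such a probe, assuming FP8 kernels require NVIDIA compute capability 8.9 (Ada) or newer; this is an assumption for illustration, not the repository's actual implementation:

```python
import torch

def is_fp8_supported():
    # Assumed threshold: FP8 (e4m3/e5m2) tensor-core kernels shipped with
    # compute capability 8.9 (Ada Lovelace) and 9.0 (Hopper).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major > 8 or (major == 8 and minor >= 9)
```

Keeping the hardware probe out of the string predicates also means `is_fp8` and `is_fp8_kv` stay cheap and side-effect free, so they can be called in code paths that only need to parse the `quantize` option.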
