Skip to content

Commit

Permalink
Add paged context fmha (NVIDIA#9526)
Browse files Browse the repository at this point in the history
  • Loading branch information
meatybobby authored Jun 25, 2024
1 parent 26aef8e commit 8c6b407
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def export(
use_embedding_sharing: bool = False,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
dtype: str = "bfloat16",
load_model: bool = True,
enable_multi_block_mode: bool = False,
Expand Down Expand Up @@ -162,6 +163,7 @@ def export(
use_parallel_embedding (bool): whether to use parallel embedding feature of TRT-LLM or not
use_embedding_sharing (bool):
paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM.
paged_context_fmha (bool): whether to use paged context fmha feature of TRT-LLM or not
remove_input_padding (bool): enables removing input padding or not.
dtype (str): Floating point type for model weights (Supports BFloat16/Float16).
load_model (bool): load TensorRT-LLM model after the export.
Expand Down Expand Up @@ -295,6 +297,7 @@ def export(
enable_multi_block_mode=enable_multi_block_mode,
paged_kv_cache=paged_kv_cache,
remove_input_padding=remove_input_padding,
paged_context_fmha=paged_context_fmha,
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
)
Expand Down
2 changes: 2 additions & 0 deletions nemo/export/trt_llm/tensorrt_llm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def build_and_save_engine(
enable_multi_block_mode: bool = False,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
max_num_tokens: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
Expand All @@ -65,6 +66,7 @@ def build_and_save_engine(
else:
plugin_config.paged_kv_cache = False
plugin_config.remove_input_padding = remove_input_padding
plugin_config.use_paged_context_fmha = paged_context_fmha

max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
Expand Down

0 comments on commit 8c6b407

Please sign in to comment.