Skip to content

Commit

Permalink
Test C++ runtime on demand in nemo_export.py to avoid possible OOMs (NVIDIA#9544)
Browse files Browse the repository at this point in the history

* Add test_cpp_runtime flag

Signed-off-by: Jan Lasek <[email protected]>

* Apply isort and black reformatting

Signed-off-by: janekl <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>
Signed-off-by: tonyjie <[email protected]>
  • Loading branch information
2 people authored and tonyjie committed Aug 6, 2024
1 parent fce371a commit 1ac4950
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions tests/export/nemo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def run_inference(
debug=True,
streaming=False,
stop_words_list=None,
test_cpp_runtime=False,
test_deployment=False,
test_data_path=None,
save_trt_engine=False,
Expand Down Expand Up @@ -316,12 +317,21 @@ def run_inference(
LOGGER.warning("Model outputs don't match the expected result.")
functional_result.regular_pass = False

if not use_lora_plugin and not ptuning and not use_vllm:
test_cpp_runtime(
engine_path=model_dir,
prompt=prompts,
output_cpp = ""
if test_cpp_runtime and not use_lora_plugin and not ptuning and not use_vllm:
# This may cause OOM for large models, as it creates a second instance of the model
exporter_cpp = TensorRTLLM(
model_dir,
load_model=True,
use_python_runtime=False,
)

output_cpp = exporter_cpp.forward(
input_texts=prompts,
max_output_len=max_output_len,
debug=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)

nq = None
Expand Down Expand Up @@ -365,6 +375,9 @@ def run_inference(
print("")
print("--- Output deployed: ", output_deployed)
print("")
print("")
print("--- Output with C++ runtime: ", output_cpp)
print("")

accuracy_result = None
if run_accuracy:
Expand All @@ -382,27 +395,6 @@ def run_inference(
raise Exception("Checkpoint {0} could not be found.".format(checkpoint_path))


# Smoke-test helper: build a fresh TensorRTLLM exporter for `engine_path`, run a
# single forward pass over `prompt`, and optionally print what was generated.
#
# Parameters:
#   engine_path    - passed as the first positional argument to TensorRTLLM.
#   prompt         - forwarded to `forward` as `input_texts`.
#   max_output_len - forwarded to `forward` unchanged.
#   debug          - when truthy, the generated output is printed.
#
# Nothing is returned, and the output is not compared against any expected
# value; the helper only exercises the load-and-generate path.
# NOTE(review): no runtime selector is passed to the constructor, so the
# "cpp runtime" in the name presumably relies on TensorRTLLM's default - confirm.
def test_cpp_runtime(
engine_path,
prompt,
max_output_len,
debug,
):
# load_model=True is requested at construction time, before any forward call.
trt_llm_exporter = TensorRTLLM(engine_path, load_model=True)
output = trt_llm_exporter.forward(
input_texts=prompt,
max_output_len=max_output_len,
# Sampling settings are hard-coded here rather than taken from the caller.
top_k=1,
top_p=0.0,
temperature=1.0,
)

# The result is only surfaced on stdout, and only in debug mode.
if debug:
print("")
print("--- Output deployed with cpp runtime: ", output)
print("")


def run_existing_checkpoints(
model_name,
use_vllm,
Expand All @@ -413,6 +405,7 @@ def run_existing_checkpoints(
lora=False,
streaming=False,
run_accuracy=False,
test_cpp_runtime=False,
test_deployment=False,
stop_words_list=None,
test_data_path=None,
Expand Down Expand Up @@ -477,6 +470,7 @@ def run_existing_checkpoints(
debug=True,
streaming=streaming,
stop_words_list=stop_words_list,
test_cpp_runtime=test_cpp_runtime,
test_deployment=test_deployment,
test_data_path=test_data_path,
save_trt_engine=save_trt_engine,
Expand Down Expand Up @@ -588,6 +582,11 @@ def get_args():
default="False",
)
parser.add_argument("--streaming", default=False, action="store_true")
parser.add_argument(
"--test_cpp_runtime",
type=str,
default="False",
)
parser.add_argument(
"--test_deployment",
type=str,
Expand Down Expand Up @@ -630,6 +629,7 @@ def str_to_bool(name: str, s: str) -> bool:
return False
raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'")

args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime)
args.test_deployment = str_to_bool("test_deployment", args.test_deployment)
args.save_trt_engine = str_to_bool("save_trt_engin", args.save_trt_engine)
args.run_accuracy = str_to_bool("run_accuracy", args.run_accuracy)
Expand Down Expand Up @@ -672,6 +672,7 @@ def run_inference_tests(args):
pp_size=args.pp_size,
streaming=args.streaming,
test_deployment=args.test_deployment,
test_cpp_runtime=args.test_cpp_runtime,
run_accuracy=args.run_accuracy,
test_data_path=args.test_data_path,
save_trt_engine=args.save_trt_engine,
Expand Down Expand Up @@ -714,6 +715,7 @@ def run_inference_tests(args):
debug=args.debug,
streaming=args.streaming,
test_deployment=args.test_deployment,
test_cpp_runtime=args.test_cpp_runtime,
test_data_path=args.test_data_path,
save_trt_engine=args.save_trt_engine,
)
Expand Down

0 comments on commit 1ac4950

Please sign in to comment.