Update infer and execute API to take prompts from txt file for BS>=1 #11

Merged. 29 commits from add_inputs_txt_file were merged into main on May 23, 2024.
Changes shown below are from 6 of the 29 commits.

Commits (29)
042f2c1
[QEff]: Update infer and execute API to take prompts from txt file fo…
quic-mamta May 16, 2024
0802373
Update infer and execute API
quic-mamta May 17, 2024
bc5ca88
Update infer and execute API
quic-mamta May 20, 2024
8712b87
Update README.md
quic-mamta May 21, 2024
81a3163
Update README.md
quic-mamta May 21, 2024
968dd41
Update README.md
quic-mamta May 21, 2024
0229664
Update infer, execute and text generation interface
quic-mamta May 21, 2024
0a40c9e
Merge branch 'main' into add_inputs_txt_file
quic-mamta May 21, 2024
e51431f
Update execute.py
quic-mamta May 21, 2024
18c973c
Update execute.py
quic-mamta May 21, 2024
cef24ab
Update text generation interface
quic-mamta May 21, 2024
b6920c4
Update Notebooks
quic-mamta May 21, 2024
20cdb52
Update README.md
quic-mamta May 21, 2024
80fb101
Update README.md
quic-mamta May 21, 2024
01999ca
Update text_generation_inference.py
quic-mamta May 21, 2024
94b7ead
Update infer and execute and text generation interface
quic-mamta May 22, 2024
885c07b
Update infer.py
quic-mamta May 22, 2024
bc615b4
Update README.md
quic-mamta May 22, 2024
6303154
Update README.md
quic-mamta May 22, 2024
52e74cb
Update README.md
quic-mamta May 22, 2024
7498451
Update infer.py
quic-mamta May 22, 2024
a6b0480
Update execute.py
quic-mamta May 22, 2024
be88571
Update files
quic-mamta May 22, 2024
0711073
Update files
quic-mamta May 22, 2024
5449fbb
Update README.md
quic-mamta May 22, 2024
17096a3
Update QEfficientGPT2.ipynb
quic-mamta May 22, 2024
107b414
Update QEfficientMPT.ipynb
quic-mamta May 22, 2024
0e567fa
Update README.md
quic-mamta May 22, 2024
ade2c13
Update README.md
quic-mamta May 23, 2024
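
Before the file-level diffs, one detail worth spelling out: both updated entry points accept exactly one of --prompt or --prompts_txt_file_path. A minimal sketch of that check, mirroring the assertion added to execute.py and infer.py below (the concrete values are only for illustration):

# Sketch of the argument check added to both CLIs; the values are illustrative.
prompt = "My name is"         # would come from --prompt
prompts_txt_file_path = None  # would come from --prompts_txt_file_path

assert (prompt is None and prompts_txt_file_path is not None) or (
    prompt is not None and prompts_txt_file_path is None
), "Please pass either single input string using --prompt or multiple inputs using --prompts_txt_file_path"

When the txt file path is given, the prompts are read one per line and validated against the batch size baked into the compiled QPC.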
49 changes: 43 additions & 6 deletions QEfficient/cloud/execute.py
@@ -11,14 +11,20 @@
from huggingface_hub import login
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.generation.text_generation_inference import (
check_batch_size_and_num_prompts,
cloud_ai_100_exec_kv,
get_compilation_batch_size,
read_prompts_txt_file,
)
from QEfficient.utils import hf_download
from QEfficient.utils.constants import Constants


def main(
model_name: str,
prompt: str,
prompts_txt_file_path: str,
qpc_path: str,
devices: List[int],
cache_dir: str = Constants.CACHE_DIR,
@@ -34,11 +40,29 @@ def main(
"""
if hf_token is not None:
login(hf_token)

# Download tokenizer along with model if it doesn't exist
model_hf_path = hf_download(repo_id=model_name, cache_dir=cache_dir, allow_patterns=["*.json"])
tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")

cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt)
assert (prompt is None and prompts_txt_file_path is not None) or (
prompt is not None and prompts_txt_file_path is None
), "Please pass either single input string using --prompt or multiple inputs using --prompts_txt_file_path"

if prompts_txt_file_path is not None:
prompt = read_prompts_txt_file(prompts_txt_file_path)

compilation_batch_size = get_compilation_batch_size(qpc_path)
check_batch_size_and_num_prompts(prompt, compilation_batch_size)

# Execute
cloud_ai_100_exec_kv(
compilation_batch_size=compilation_batch_size,
tokenizer=tokenizer,
qpc=qpc_path,
device_id=devices,
prompt=prompt,
)


if __name__ == "__main__":
@@ -49,9 +73,14 @@ def main(
parser.add_argument("--qpc_path", "--qpc-path", required=True, help="Path to generated QPC")
parser.add_argument(
"--prompt",
type=lambda prompt: prompt.split("|"),
default="My name is",
help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
type=str,
help="Input prompt, if executing for batch size>1, use prompts_txt_file_path flag",
)
parser.add_argument(
"--prompts_txt_file_path",
"--prompts-txt-file-path-file-path",
type=str,
help="for batch size>1, pass input prompts in txt file, sample prompts.txt file present in examples folder",
)
parser.add_argument(
"--device_group",
@@ -67,4 +96,12 @@ def main(
"--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
)
args = parser.parse_args()
main(args.model_name, args.prompt, args.qpc_path, args.device_group, args.cache_dir, args.hf_token)
main(
args.model_name,
args.prompt,
args.prompts_txt_file_path,
args.qpc_path,
args.device_group,
args.cache_dir,
args.hf_token,
)
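
For orientation, here is a minimal usage sketch of the flow that execute.py now follows, written against the helper signatures defined in text_generation_inference.py in this revision of the diff; the model name, QPC path, and prompts file below are placeholders, not values fixed by the PR.

# Usage sketch only; "gpt2", the QPC path, and the prompts file are placeholders.
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import (
    check_batch_size_and_num_prompts,
    cloud_ai_100_exec_kv,
    get_compilation_batch_size,
    read_prompts_txt_file,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_cache=True, padding_side="left")

# One prompt per line in the txt file; returns a list of stripped strings.
prompt = read_prompts_txt_file("examples/prompts.txt")

# The compiled batch size is read back from specializations.json, which sits
# one directory above the qpcs folder.
qpc_path = "qeff_models/gpt2/qpcs"  # placeholder path
compilation_batch_size = get_compilation_batch_size(qpc_path)
check_batch_size_and_num_prompts(prompt, compilation_batch_size)

cloud_ai_100_exec_kv(
    compilation_batch_size=compilation_batch_size,
    tokenizer=tokenizer,
    qpc_path=qpc_path,
    device_id=[0],
    prompt=prompt,
)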
58 changes: 49 additions & 9 deletions QEfficient/cloud/infer.py
@@ -15,7 +15,12 @@
import QEfficient
from QEfficient.cloud.compile import main as compile
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.generation.text_generation_inference import (
check_batch_size_and_num_prompts,
cloud_ai_100_exec_kv,
get_compilation_batch_size,
read_prompts_txt_file,
)
from QEfficient.utils import hf_download
from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
from QEfficient.utils.logging_utils import logger
@@ -48,7 +53,8 @@ def onnx_exists(onnx_file_path: str) -> bool:
def main(
model_name: str,
num_cores: int,
prompt: str,
prompt: str = None,
prompts_txt_file_path: str = None,
aic_enable_depth_first: bool = False,
mos: int = -1,
cache_dir: str = Constants.CACHE_DIR,
@@ -76,6 +82,13 @@ def main(
onnx_dir_path = os.path.join(model_card_dir, "onnx")
onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")

assert (prompt is None and prompts_txt_file_path is not None) or (
prompt is not None and prompts_txt_file_path is None
), "Please pass either single input string using --prompt or multiple inputs using --prompts_txt_file_path"

if prompts_txt_file_path is not None:
prompt = read_prompts_txt_file(prompts_txt_file_path)

# Get tokenizer
if hf_token is not None:
login(hf_token)
@@ -89,9 +102,19 @@ def main(
if qpc_exists(qpc_dir_path):
# execute
logger.info("Pre-compiled qpc found! Trying to execute with given prompt")
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
compilation_batch_size = get_compilation_batch_size(qpc_dir_path)
check_batch_size_and_num_prompts(prompt, compilation_batch_size)
cloud_ai_100_exec_kv(
compilation_batch_size=compilation_batch_size,
tokenizer=tokenizer,
qpc_path=qpc_dir_path,
device_id=device_group,
prompt=prompt,
)
return

check_batch_size_and_num_prompts(prompt, batch_size)

if onnx_exists(onnx_model_path):
# Compile -> execute
# We need to pass parent directory of qpc_dir_path, as the compile function handles the qpcs directory creation
@@ -110,7 +133,13 @@ def main(
assert (
generated_qpc_path == qpc_dir_path
), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
cloud_ai_100_exec_kv(
compilation_batch_size=compilation_batch_size,
tokenizer=tokenizer,
qpc_path=qpc_dir_path,
device_id=device_group,
prompt=prompt,
)
return

#############################################
@@ -157,12 +186,18 @@ def main(
logger.info(f"Compiled qpc files can be found at : {generated_qpc_path}")

# Execute
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
cloud_ai_100_exec_kv(
compilation_batch_size=compilation_batch_size,
tokenizer=tokenizer,
qpc_path=qpc_dir_path,
device_id=device_group,
prompt=prompt,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Inference command, the model will be downloaded from HF, optmized, compiled, executed on AIC"
description="Inference command, the model will be downloaded from HF, optmized, compiled, executed on Cloud AI 100"
)
parser.add_argument("--model-name", "--model_name", required=True, help="HF Model card name/id")
parser.add_argument(
@@ -191,9 +226,14 @@ def main(
)
parser.add_argument(
"--prompt",
type=lambda prompt: prompt.split("|"),
default="My name is",
help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
type=str,
help="Input prompt, if executing for batch size>1, use prompts_txt_file_path flag",
)
parser.add_argument(
"--prompts_txt_file_path",
"--prompts-txt-file-path",
type=str,
help="for batch size>1, pass input prompts in txt file, sample prompts.txt file present in examples folder",
)
parser.add_argument(
"--aic_enable_depth_first",
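
The new --prompts_txt_file_path flag expects a plain text file with one prompt per line. A small round-trip illustration through the new helpers follows; the file name, prompt strings, and batch size are made up for the example.

# Illustration only; the prompts and batch size are example values.
from QEfficient.generation.text_generation_inference import (
    check_batch_size_and_num_prompts,
    read_prompts_txt_file,
)

with open("prompts.txt", "w") as f:
    f.write("My name is\nThe capital of France is\nHello, my favourite hobby is\n")

prompt = read_prompts_txt_file("prompts.txt")
assert prompt == ["My name is", "The capital of France is", "Hello, my favourite hobby is"]

# Passes when the QPC was compiled with batch_size == len(prompt);
# raises an AssertionError on a mismatch for batch sizes > 1.
check_batch_size_and_num_prompts(prompt, 3)

# For a QPC compiled with batch_size 1, cloud_ai_100_exec_kv instead loops over
# the prompts one at a time, so any number of prompts is accepted.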
103 changes: 96 additions & 7 deletions QEfficient/generation/text_generation_inference.py
@@ -61,6 +61,75 @@ def write_io_files(
json.dump({"IO-files": io_files}, fp, indent=True)


def get_compilation_batch_size(qpc_path: str):
qpc_base_path = os.path.dirname(qpc_path)
print(qpc_base_path)
specialization_file_path = os.path.join(qpc_base_path, "specializations.json")
print(specialization_file_path)
with open(specialization_file_path, "r") as file:
data = json.load(file)
compilation_batch_size = int(data["specializations"][0]["batch_size"])
return compilation_batch_size


def check_batch_size_and_num_prompts(prompt: Union[str, List], compilation_batch_size: int):
if isinstance(prompt, list):
num_prompts = len(prompt)
elif isinstance(prompt, str):
num_prompts = 1
else:
print("Input prompt sould be either string for single input or List of string in case of mutliple inputs")
if compilation_batch_size > 1:
assert (
compilation_batch_size == num_prompts
), f"Mismatch between number of prompts {num_prompts} and compilation batch size {compilation_batch_size}; please pass correct input argument"


def read_prompts_txt_file(prompts_txt_file_path: str):
prompt = []
with open(prompts_txt_file_path, "r") as file:
for line in file:
prompt.append(line.strip())
return prompt


def cloud_ai_100_exec_kv(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
qpc_path: str,
prompt: Union[str, List],
compilation_batch_size: int,
device_id: List[int] = [0],
):
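# QPC compiled for batch size 1 but a list of prompts given: run the prompts
# sequentially; latency stats are printed for the final prompt.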
if compilation_batch_size == 1 and isinstance(prompt, list):
for i in range(len(prompt)):
latency_stats = exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=device_id, prompt=prompt[i])
if i == len(prompt) - 1:
generated_texts, prefill_time, decode_perf, total_perf, total_time = latency_stats
print_latency_stats_kv(
prompt,
generated_texts,
compilation_batch_size,
prefill_time,
decode_perf,
total_perf,
total_time,
automation=False,
)
else:
latency_stats = exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=device_id, prompt=prompt)
generated_texts, prefill_time, decode_perf, total_perf, total_time = latency_stats
print_latency_stats_kv(
prompt,
generated_texts,
compilation_batch_size,
prefill_time,
decode_perf,
total_perf,
total_time,
automation=False,
)


def latency_stats_bertstyle(
model_name: str,
qpc: str,
@@ -97,25 +166,26 @@ def latency_stats_bertstyle(
print(round((cur_len - init_len) / (end - start), 2), "tok/s")


def cloud_ai_100_exec_kv(
def exec_kv(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
qpc: str,
prompt: str,
prompt: Union[str, List],
input_len: Optional[int] = None,
generation_len: Optional[int] = None,
device_id: List[int] = [0],
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
automation: bool = False,
):
if tokenizer.padding_side != "left":
logger.warning(f"Please use padding_side='left' while initializing the tokenizer")
logger.warning("Please use padding_side='left' while initializing the tokenizer")
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load QPC
session = QAICInferenceSession(qpc, device_id, enable_debug_logs=enable_debug_logs)

# Read prompt and ctx len from session
prompt_len = max([x[session.binding_index_map["input_ids"]][1][1] for x in session.allowed_shapes])
ctx_len = session.allowed_shapes[0][session.binding_index_map["attention_mask"]][1][1]
@@ -126,11 +196,11 @@ def cloud_ai_100_exec_kv(
num_chunks = -(input_len // -prompt_len) # ceil divide without float
input_len = num_chunks * prompt_len # Convert input_len to a multiple of prompt_len
assert input_len <= ctx_len, "input_len should be less than ctx_len"

# Skip inputs/outputs
session.skip_buffers([x for x in session.input_names if x.startswith("past_")])
session.skip_buffers([x for x in session.output_names if x.endswith("_RetainedState")])
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare inputs for first iteration
start = perf_counter()
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=input_len)
@@ -146,8 +216,13 @@ def cloud_ai_100_exec_kv(
cache_index = np.array([0])
inputs["cache_index"] = cache_index
generated_ids = np.full((batch_size, generation_len - input_len + 1), tokenizer.pad_token_id)

if stream:
print(0, prompt[0], end=" ", flush=True)
if isinstance(prompt, list):
print(0, prompt[0], end=" ", flush=True)
else:
print(0, prompt, end=" ", flush=True)

# Run prefill
for i in range(num_chunks):
chunk_inputs = inputs.copy()
@@ -159,6 +234,7 @@ def cloud_ai_100_exec_kv(
if write_io_dir:
write_io_files(inputs, outputs, write_io_dir, "prefill", "aic_batch_io", True, False)
cache_index += prompt_len

# Get first token
logits = outputs["logits"]
if len(logits.shape) == 2:
@@ -169,6 +245,7 @@ def cloud_ai_100_exec_kv(
generated_ids[:, cache_index[0] - input_len] = next_token_id.squeeze(1)
if stream:
print(tokenizer.decode(next_token_id[0]), end=" ", flush=True)

# Skip attention_mask from next iteration to use retained attention_mask
session.skip_buffers(["attention_mask"])
loop_start = perf_counter()
@@ -178,6 +255,7 @@ def cloud_ai_100_exec_kv(
if write_io_dir:
write_io_files(inputs, outputs, write_io_dir, "decode", "aic_batch_io", True, False)
write_io_dir = None

# Prepare inputs for next iteration
logits = outputs["logits"]
if len(logits.shape) == 2:
@@ -192,14 +270,24 @@ def cloud_ai_100_exec_kv(
print(tokenizer.decode(next_token_id[0]), end=" ", flush=True)
end = perf_counter()
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

for i in range(1 if stream else 0, batch_size):
print()
print(i, prompt[i], generated_texts[i])

prefill_time = loop_start - start
decode_perf = (cache_index.item() - input_len - 1) / (end - loop_start)
total_perf = (cache_index.item() - input_len) / (end - start)
total_time = end - start
print()

latency_stats = (generated_texts, prefill_time, decode_perf, total_perf, total_time)
return latency_stats


def print_latency_stats_kv(
prompt, generated_texts, batch_size, prefill_time, decode_perf, total_perf, total_time, automation: bool = False
):
if automation:
print()
print("input=", prompt)
@@ -210,6 +298,7 @@ def cloud_ai_100_exec_kv(
print("Total (E2E) inference time is=", round(total_time, 2))
return
print()

print("===================== Performance Stats =====================")
if batch_size > 1:
print("Prefill time a.k.a TTFT (batch) is :", round(prefill_time, 2), "s")
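
get_compilation_batch_size above recovers the batch size from the specializations.json written next to the qpcs directory at compile time. A self-contained sketch of that lookup follows; the field values other than batch_size are illustrative.

# Sketch of the on-disk layout get_compilation_batch_size expects; only the
# "batch_size" field is read, the other values are illustrative.
import json
import os
import tempfile

qpc_base_path = tempfile.mkdtemp()              # stands in for the compile output dir
qpc_path = os.path.join(qpc_base_path, "qpcs")  # the path handed to the helper
os.makedirs(qpc_path)

spec = {"specializations": [{"batch_size": "2", "seq_len": "32", "ctx_len": "128"}]}
with open(os.path.join(qpc_base_path, "specializations.json"), "w") as f:
    json.dump(spec, f)

# Equivalent to get_compilation_batch_size(qpc_path):
with open(os.path.join(os.path.dirname(qpc_path), "specializations.json")) as f:
    data = json.load(f)
print(int(data["specializations"][0]["batch_size"]))  # -> 2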