[QEff]: Update infer and execute API to take prompts from txt file for bs>1

Signed-off-by: mamtsing <[email protected]>
quic-mamta committed May 16, 2024
1 parent 700236a commit 861e60c
Showing 3 changed files with 81 additions and 9 deletions.
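
With this change, both entry points read one prompt per line from a text file when the QPC was compiled with batch size > 1. A minimal usage sketch of the updated execute API, assuming a hypothetical model card and QPC path (the directory name must encode the batch size as "<n>BS"):

from QEfficient.cloud.execute import main

main(
    model_name="gpt2",                        # hypothetical model card
    prompt="My name is",                      # overridden once inputs_file_path is given
    inputs_file_path="examples/prompts.txt",  # one prompt per line, added by this commit
    qpc_path="qpc/gpt2_16cores_3BS_32PL_128CL_mxfp6/qpcs",  # hypothetical; "3BS" -> batch size 3
    devices=[0],
)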
44 changes: 40 additions & 4 deletions QEfficient/cloud/execute.py
@@ -19,6 +19,7 @@
def main(
model_name: str,
prompt: str,
inputs_file_path: str,
qpc_path: str,
devices: List[int],
cache_dir: str = Constants.CACHE_DIR,
@@ -38,7 +39,29 @@ def main(
model_hf_path = hf_download(repo_id=model_name, cache_dir=cache_dir, allow_patterns=["*.json"])
tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")

cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt)
# Read one prompt per line from the inputs file (used for batch size > 1)
if inputs_file_path is not None:
try:
prompt = []
with open(inputs_file_path, "r") as file:
for line in file:
prompt.append(line.strip())
except FileNotFoundError:
print("Inputs file not found.")

# The QPC directory name encodes the compilation batch size as "<n>BS"; recover it from the path
qpc_dir_name = qpc_path.strip("/").split("/")[-2]
compilation_batch_size = int(qpc_dir_name.split("BS")[0].split("_")[-1])

if compilation_batch_size > 1:
assert (
compilation_batch_size == len(prompt)
), f"Mismatch between number of prompts {len(prompt)} and compilation batch size {compilation_batch_size}; please pass correct input argument"

# Execute: with compilation batch size 1, run each prompt in turn; otherwise pass the full list as one batch
if compilation_batch_size == 1 and isinstance(prompt, list):
for i in range(len(prompt)):
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt[i])
else:
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt)


if __name__ == "__main__":
@@ -49,9 +72,14 @@ def main(
parser.add_argument("--qpc_path", "--qpc-path", required=True, help="Path to generated QPC")
parser.add_argument(
"--prompt",
type=lambda prompt: prompt.split("|"),
type=str,
default="My name is",
help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
help="Input prompt, if executing for batch size>1, use inputs_file_path flag",
)
parser.add_argument(
"--inputs_file_path",
type=str,
help="for batch size>1, pass input prompts in txt file, sample prompts.txt file present in examples folder",
)
parser.add_argument(
"--device_group",
@@ -67,4 +95,12 @@ def main(
"--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
)
args = parser.parse_args()
main(args.model_name, args.prompt, args.qpc_path, args.device_group, args.cache_dir, args.hf_token)
main(
args.model_name,
args.prompt,
args.inputs_file_path,
args.qpc_path,
args.device_group,
args.cache_dir,
args.hf_token,
)
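
For illustration, the batch-size parsing above applied to a hypothetical QPC path (assuming the directory naming encodes the batch size as "<n>BS", which the code relies on):

qpc_path = "qpc/gpt2_16cores_3BS_32PL_128CL_mxfp6/qpcs"
qpc_dir_name = qpc_path.strip("/").split("/")[-2]  # "gpt2_16cores_3BS_32PL_128CL_mxfp6"
compilation_batch_size = int(qpc_dir_name.split("BS")[0].split("_")[-1])  # int("3") -> 3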
43 changes: 38 additions & 5 deletions QEfficient/cloud/infer.py
@@ -49,6 +49,7 @@ def main(
model_name: str,
num_cores: int,
prompt: str,
inputs_file_path: str,
aic_enable_depth_first: bool = False,
mos: int = -1,
cache_dir: str = Constants.CACHE_DIR,
@@ -76,6 +77,20 @@ def main(
onnx_dir_path = os.path.join(model_card_dir, "onnx")
onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")

# Read one prompt per line from the inputs file (used for batch size > 1)
if inputs_file_path is not None:
try:
prompt = []
with open(inputs_file_path, "r") as file:
for line in file:
prompt.append(line.strip())
except FileNotFoundError:
print("Inputs file not found.")

if batch_size > 1:
assert (
batch_size == len(prompt)
), f"Mismatch between number of prompts {len(prompt)} and batch size {batch_size}; please pass correct input argument"

# Get tokenizer
if hf_token is not None:
login(hf_token)
@@ -89,7 +104,11 @@ def main(
if qpc_exists(qpc_dir_path):
# execute
logger.info("Pre-compiled qpc found! Trying to execute with given prompt")
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
if batch_size == 1 and isinstance(prompt, list):
for i in range(len(prompt)):
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt[i])
else:
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
return

if onnx_exists(onnx_model_path):
@@ -110,7 +129,11 @@ def main(
assert (
generated_qpc_path == qpc_dir_path
), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
if batch_size == 1 and isinstance(prompt, list):
for i in range(len(prompt)):
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt[i])
else:
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
return

#############################################
@@ -157,7 +180,11 @@ def main(
logger.info(f"Compiled qpc files can be found at : {generated_qpc_path}")

# Execute
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
if batch_size == 1 and isinstance(prompt, list):
for i in range(len(prompt)):
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt[i])
else:
cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)


if __name__ == "__main__":
@@ -191,9 +218,15 @@ def main(
)
parser.add_argument(
"--prompt",
type=lambda prompt: prompt.split("|"),
type=str,
default="My name is",
help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
help="Input prompt, if executing for batch size>1, use inputs_file_path flag",
)
parser.add_argument(
"--inputs_file_path",
"--inputs-file-path",
type=str,
help="for batch size>1, pass input prompts in txt file, sample prompts.txt file present in examples folder",
)
parser.add_argument(
"--aic_enable_depth_first",
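
The batch-aware dispatch now appears at three call sites in infer.py and one in execute.py; a hypothetical helper could fold them into one place. A sketch only, not part of this commit, with cloud_ai_100_exec_kv imported as in the modules above:

def exec_prompts(tokenizer, qpc, device_id, prompt, batch_size):
    # Mirrors the dispatch added in this commit: with batch size 1, run each
    # prompt in turn; otherwise hand the whole list to one batched call.
    if batch_size == 1 and isinstance(prompt, list):
        for p in prompt:
            cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc, device_id=device_id, prompt=p)
    else:
        cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc, device_id=device_id, prompt=prompt)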
3 changes: 3 additions & 0 deletions examples/prompts.txt
@@ -0,0 +1,3 @@
My name is
The sun rises from
The flat earth theory is the belief that
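
Each line above becomes one prompt after stripping whitespace. The reading loop both entry points now share is equivalent to this standalone sketch (load_prompts is a hypothetical name):

def load_prompts(inputs_file_path: str) -> list:
    # One prompt per line, surrounding whitespace stripped
    with open(inputs_file_path, "r") as file:
        return [line.strip() for line in file]

# load_prompts("examples/prompts.txt")
# -> ["My name is", "The sun rises from", "The flat earth theory is the belief that"]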
