diff --git a/QEfficient/cloud/execute.py b/QEfficient/cloud/execute.py
index 2bd5626e..9734e3d0 100644
--- a/QEfficient/cloud/execute.py
+++ b/QEfficient/cloud/execute.py
@@ -19,6 +19,7 @@ def main(
     model_name: str,
     prompt: str,
+    inputs_file_path: str,
     qpc_path: str,
     devices: List[int],
     cache_dir: str = Constants.CACHE_DIR,
@@ -38,7 +39,29 @@ def main(
     model_hf_path = hf_download(repo_id=model_name, cache_dir=cache_dir, allow_patterns=["*.json"])
     tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")
 
-    cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt)
+    if inputs_file_path is not None:
+        try:
+            prompt = []
+            with open(inputs_file_path, "r") as file:
+                for line in file:
+                    prompt.append(line.strip())
+        except FileNotFoundError:
+            print("Inputs file not found.")
+
+    qpc_dir_name = qpc_path.strip("/").split("/")[-2]
+    compilation_batch_size = int(qpc_dir_name.split("BS")[0].split("_")[-1])
+
+    if compilation_batch_size > 1:
+        assert (
+            compilation_batch_size == len(prompt)
+        ), f"Mismatch between number of prompts {len(prompt)} and compilation batch size {compilation_batch_size}; please pass correct input argument"
+
+    # Execute
+    if compilation_batch_size == 1 and isinstance(prompt, list):
+        for i in range(len(prompt)):
+            cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt[i])
+    else:
+        cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_path, device_id=devices, prompt=prompt)
 
 
 if __name__ == "__main__":
@@ -49,9 +72,14 @@ def main(
     parser.add_argument("--qpc_path", "--qpc-path", required=True, help="Path to generated QPC")
     parser.add_argument(
         "--prompt",
-        type=lambda prompt: prompt.split("|"),
+        type=str,
         default="My name is",
-        help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
+        help="Input prompt; if executing for batch size>1, use the inputs_file_path flag",
+    )
+    parser.add_argument(
+        "--inputs_file_path",
+        type=str,
+        help="For batch size>1, pass input prompts in a txt file; a sample prompts.txt file is present in the examples folder",
     )
     parser.add_argument(
         "--device_group",
@@ -67,4 +95,12 @@ def main(
         "--hf-token", "--hf_token", default=None, type=str, required=False, help="HF token id for private HF models"
     )
     args = parser.parse_args()
-    main(args.model_name, args.prompt, args.qpc_path, args.device_group, args.cache_dir, args.hf_token)
+    main(
+        args.model_name,
+        args.prompt,
+        args.inputs_file_path,
+        args.qpc_path,
+        args.device_group,
+        args.cache_dir,
+        args.hf_token,
+    )
diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index 3492874a..de5c2743 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -49,6 +49,7 @@ def main(
     model_name: str,
     num_cores: int,
     prompt: str,
+    inputs_file_path: str,
     aic_enable_depth_first: bool = False,
     mos: int = -1,
     cache_dir: str = Constants.CACHE_DIR,
@@ -76,6 +77,20 @@ def main(
     onnx_dir_path = os.path.join(model_card_dir, "onnx")
     onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")
 
+    if inputs_file_path is not None:
+        try:
+            prompt = []
+            with open(inputs_file_path, "r") as file:
+                for line in file:
+                    prompt.append(line.strip())
+        except FileNotFoundError:
+            print("Inputs file not found.")
+
+    if batch_size > 1:
+        assert (
+            batch_size == len(prompt)
+        ), f"Mismatch between number of prompts {len(prompt)} and batch size {batch_size}; please pass correct input argument"
+
     # Get tokenizer
     if hf_token is not None:
         login(hf_token)
@@ -89,7 +104,11 @@ def main(
     if qpc_exists(qpc_dir_path):
         # execute
         logger.info("Pre-compiled qpc found! Trying to execute with given prompt")
-        cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
+        if batch_size == 1 and isinstance(prompt, list):
+            for i in range(len(prompt)):
+                cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt[i])
+        else:
+            cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
         return
 
     if onnx_exists(onnx_model_path):
@@ -110,7 +129,11 @@ def main(
         assert (
             generated_qpc_path == qpc_dir_path
         ), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
-        cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
+        if batch_size == 1 and isinstance(prompt, list):
+            for i in range(len(prompt)):
+                cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt[i])
+        else:
+            cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
         return
 
     #############################################
@@ -157,7 +180,11 @@ def main(
     logger.info(f"Compiled qpc files can be found at : {generated_qpc_path}")
 
     # Execute
-    cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
+    if batch_size == 1 and isinstance(prompt, list):
+        for i in range(len(prompt)):
+            cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt[i])
+    else:
+        cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
 
 
 if __name__ == "__main__":
@@ -191,9 +218,15 @@ def main(
     )
     parser.add_argument(
         "--prompt",
-        type=lambda prompt: prompt.split("|"),
+        type=str,
         default="My name is",
-        help="Input prompt, if executing for batch size>1, pass input promprs in single string but seperate with pipe (|) symbol",
+        help="Input prompt; if executing for batch size>1, use the inputs_file_path flag",
+    )
+    parser.add_argument(
+        "--inputs_file_path",
+        "--inputs-file-path",
+        type=str,
+        help="For batch size>1, pass input prompts in a txt file; a sample prompts.txt file is present in the examples folder",
     )
     parser.add_argument(
         "--aic_enable_depth_first",
diff --git a/examples/prompts.txt b/examples/prompts.txt
new file mode 100644
index 00000000..a91a5151
--- /dev/null
+++ b/examples/prompts.txt
@@ -0,0 +1,3 @@
+My name is
+The sun rises from
+The flat earth theory is the belief that
\ No newline at end of file
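
The updated execute.main() can also be exercised directly from Python as a quick check of the inputs_file_path plumbing. The sketch below is illustrative only and not part of the patch: the model card, QPC directory layout, and device id are assumptions. Note that the parent directory of qpc_path must encode the compilation batch size, since the code parses it with qpc_dir_name.split("BS")[0].split("_")[-1].

# Illustrative sketch (not part of the patch): drive the new inputs_file_path
# argument of QEfficient.cloud.execute.main(). Model name, QPC path, and device
# ids below are assumptions for demonstration only.
from QEfficient.cloud.execute import main

main(
    model_name="gpt2",                                     # assumed model card
    prompt="My name is",                                   # ignored when inputs_file_path is given
    inputs_file_path="examples/prompts.txt",               # one prompt per line; the sample file has three
    qpc_path="/path/to/qpc_16cores_3BS_32PL_128CL/qpcs",   # assumed layout; "3BS" must match the prompt count
    devices=[0],                                           # assumed single-device setup
)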