From 643bb2c02f7378e671f1416c6342f9f21ba5ffc7 Mon Sep 17 00:00:00 2001 From: Ilango Rajagopal Date: Tue, 27 Aug 2024 00:32:31 +0530 Subject: [PATCH] Added SplitTensorsTransform & Auto-device-picker (#89) * fix: Transform names Signed-off-by: Ilango Rajagopal * OnnxTransforms need **kwargs Signed-off-by: Ilango Rajagopal * Added `SplitTensorsTransform` Signed-off-by: Ilango Rajagopal * fix: `onnx_base_dir` should be passed as kwarg Signed-off-by: Ilango Rajagopal * Auto-device-picker for QAICInferenceSession Signed-off-by: Ilango Rajagopal * Make device_id optional Signed-off-by: Ilango Rajagopal * Use auto-device-picker for tests Signed-off-by: Ilango Rajagopal * Remove LoraAdapters placeholder Signed-off-by: Ilango Rajagopal * Fix compile API to use None device_group Signed-off-by: Ilango Rajagopal * fix: get_qpc_dir when device_group=None Signed-off-by: Ilango Rajagopal * Parallelizing pytests Signed-off-by: Onkar Chougule * fixed parallel tests Signed-off-by: Onkar Chougule * fixing parallel tests Signed-off-by: Onkar Chougule * linter Signed-off-by: Onkar Chougule * Update docstring for optional `device_group` Signed-off-by: Ilango Rajagopal * Add docstrings to new transforms Signed-off-by: Ilango Rajagopal * fixed parallelizing tests Signed-off-by: Onkar Chougule * parallelzing cli tests too Signed-off-by: Onkar Chougule * extra attempt to reduce tests time Signed-off-by: Onkar Chougule * Fix docstring for compile function Signed-off-by: Ilango Rajagopal * fixed junit xml files Signed-off-by: Onkar Chougule * bugfix Signed-off-by: Onkar Chougule * fix: typo on pytest-xdist Signed-off-by: Ilango Rajagopal * Move junit_logging init option to pyproject.toml Signed-off-by: Ilango Rajagopal * Move standard pytest flags to pyproject.toml Signed-off-by: Ilango Rajagopal * fix pyproject.toml Signed-off-by: Ilango Rajagopal --------- Signed-off-by: Ilango Rajagopal Signed-off-by: Onkar Chougule Co-authored-by: Onkar Chougule --- QEfficient/base/onnx_transforms.py | 50 ++++++-- QEfficient/cloud/execute.py | 5 +- QEfficient/cloud/infer.py | 5 +- QEfficient/compile/compile_helper.py | 8 +- QEfficient/exporter/export_utils.py | 4 +- QEfficient/generation/cloud_infer.py | 22 ++-- .../generation/text_generation_inference.py | 8 +- QEfficient/utils/_utils.py | 2 +- QEfficient/utils/run_utils.py | 2 +- pyproject.toml | 2 + scripts/Jenkinsfile | 5 +- tests/base/test_onnx_transforms.py | 92 +++++++++----- tests/cloud/conftest.py | 8 +- tests/cloud/high_level_testing.json | 4 +- tests/cloud/test_compile.py | 3 + tests/cloud/test_execute.py | 3 +- tests/cloud/test_export.py | 3 + tests/cloud/test_infer.py | 2 +- .../models/test_causal_lm_models.py | 109 +++++++++++------ tests/utils.py | 112 +----------------- 20 files changed, 229 insertions(+), 220 deletions(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index a55e773b..543ec4e2 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -20,11 +20,11 @@ def __init__(self): raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.") @classmethod - def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ModelProto, bool]: + def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: """ Override this class to apply a transformation. 
:param model: The model's ONNX graph to transform - :param onnx_base_dir: Directory where the model and external files are present + :param kwargs: Parameters needed for specific transforms. All transforms should take **kwargs to ignore unneeded kwargs. :returns: ONNX graph after applying the transform :returns: Boolean indicating whether transform was applied @@ -32,13 +32,16 @@ def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ raise NotImplementedError("Use subclasses for ONNX transform") -class FP16Clip(OnnxTransform): +class FP16ClipTransform(OnnxTransform): """ Clips the tensor values to be in FP16 range. """ @classmethod - def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ModelProto, bool]: + def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]: + """ + :param onnx_base_dir: Base directory to load tensors (if not already loaded). + """ finfo = np.finfo(np.float16) fp16_max = finfo.max fp16_min = finfo.min @@ -53,9 +56,38 @@ def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ return model, transformed -class SplitWeights(OnnxTransform): - pass - +class SplitTensorsTransform(OnnxTransform): + """ + Split external tensors file + """ -class LoraAdapters(OnnxTransform): - pass + @classmethod + def apply( + cls, + model: ModelProto, + *, + model_name: str, + onnx_base_dir: Optional[str] = None, + file_chunk_size: int = 10 * 2**30, # 10 GiB + size_threshold: int = 1024, + **kwargs, + ) -> Tuple[ModelProto, bool]: + """ + :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data + :param onnx_base_dir: Base directory to load tensors (if not already loaded). + :param file_chunk_size: Chunk size to split external files into. + :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally. + """ + file_num = 0 + current_file_size = 0 + transformed = False + external_data_helper.load_external_data_for_model(model, onnx_base_dir) + for tensor in external_data_helper._get_all_tensors(model): + if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): + transformed = True + current_file_size += tsize + if current_file_size > file_chunk_size: + file_num += 1 + current_file_size = tsize + external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") + return model, transformed diff --git a/QEfficient/cloud/execute.py b/QEfficient/cloud/execute.py index c6145dfd..483303b0 100644 --- a/QEfficient/cloud/execute.py +++ b/QEfficient/cloud/execute.py @@ -16,7 +16,7 @@ def main( model_name: str, qpc_path: str, - device_group: List[int], + device_group: Optional[List[int]] = None, local_model_dir: Optional[str] = None, prompt: Optional[str] = None, # type: ignore prompts_txt_file_path: Optional[str] = None, @@ -30,8 +30,8 @@ def main( ``Mandatory`` Args: :model_name (str): Hugging Face Model Card name, Example: ``gpt2``. :qpc_path (str): Path to the generated binary after compilation. - :device_group (List[int]): Device Ids to be used for compilation. if len(device_group) > 1. Multiple Card setup is enabled. ``Optional`` Args: + :device_group (List[int]): Device Ids to be used for compilation. if len(device_group) > 1. Multiple Card setup is enabled. ``Defaults to None.`` :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.`` :prompt (str): Sample prompt for the model text generation. 
``Defaults to None.`` :prompts_txt_file_path (str): Path to txt file for multiple input prompts. ``Defaults to None.`` @@ -69,7 +69,6 @@ def main( parser.add_argument( "--device_group", "--device-group", - required=True, type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], help="Cloud AI 100 device ids (comma-separated) e.g. [0]", ) diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 44d93933..75ef255c 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -20,7 +20,7 @@ def main( model_name: str, num_cores: int, - device_group: List[int], + device_group: Optional[List[int]] = None, prompt: Optional[str] = None, # type: ignore prompts_txt_file_path: Optional[str] = None, aic_enable_depth_first: bool = False, @@ -39,8 +39,8 @@ def main( ``Mandatory`` Args: :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` :num_cores (int): Number of cores to compile model on. - :device_group (List[int]): Device Ids to be used for compilation. If ``len(device_group) > 1``, multiple Card setup is enabled. ``Optional`` Args: + :device_group (List[int]): Device Ids to be used for compilation. If ``len(device_group) > 1``, multiple Card setup is enabled. ``Defaults to None.`` :prompt (str): Sample prompt for the model text generation. ``Defaults to None.`` :prompts_txt_file_path (str): Path to txt file for multiple input prompts. ``Defaults to None.`` :aic_enable_depth_first (bool): Enables ``DFS`` with default memory size. ``Defaults to False.`` @@ -147,7 +147,6 @@ def main( parser.add_argument( "--device_group", "--device-group", - required=True, type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], help="Cloud AI 100 device ids (comma-separated) e.g. [0,1] ", ) diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index bcd4c7fe..5d9e919c 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -40,7 +40,7 @@ def compile_kv_model_on_cloud_ai_100( custom_io_path: str, aic_enable_depth_first: bool, mos: int = -1, - device_group: List[int] = [0], + device_group: Optional[List[int]] = None, **kwargs, ) -> Tuple[bool, str]: if kwargs: @@ -74,7 +74,7 @@ def compile_kv_model_on_cloud_ai_100( command.append(f"-mos={mos}") if aic_enable_depth_first: command.append("-aic-enable-depth-first") - if len(device_group) > 1: + if device_group is not None and len(device_group) > 1: mdp_ts_config = { "connections": [{"devices": list(range(len(device_group))), "type": "p2p"}], "partitions": [ @@ -101,7 +101,7 @@ def compile( onnx_path: str, qpc_path: str, num_cores: int, - device_group: List[int], # FIXME: use num_devices instead + device_group: Optional[List[int]] = None, # FIXME: use num_devices instead aic_enable_depth_first: bool = False, mos: int = -1, batch_size: int = 1, @@ -122,8 +122,8 @@ def compile( :onnx_path (str): Generated ``ONNX`` Model Path. :qpc_path (str): Path for saving compiled qpc binaries. :num_cores (int): Number of cores to compile the model on. - :device_group (List[int]): Used for finding the number of devices to compile for. ``Optional`` Args: + :device_group (List[int]): Used for finding the number of devices to compile for. ``Defaults to None.`` :aic_enable_depth_first (bool): Enables ``DFS`` with default memory size. ``Defaults to False.`` :mos (int): Effort level to reduce the on-chip memory. ``Defaults to -1.`` :batch_size (int): Batch size to compile the model for. 
``Defaults to 1.`` diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py index 75ee08a8..ecf291ff 100644 --- a/QEfficient/exporter/export_utils.py +++ b/QEfficient/exporter/export_utils.py @@ -17,7 +17,7 @@ import torch from onnx import external_data_helper -from QEfficient.base.onnx_transforms import FP16Clip +from QEfficient.base.onnx_transforms import FP16ClipTransform def export_onnx( @@ -215,7 +215,7 @@ def fix_onnx_fp16( model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx")) # TODO: Remove this `fix_onnx_fp16` function and replace with this transform # as we're not utilizing the validations done in this function - model, fp16_fix = FP16Clip.apply(model, gen_models_path) + model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path) if fp16_fix: # Save FP16 model diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index aac3d60d..998dc82b 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Dict, List +from typing import Dict, List, Optional from warnings import warn import numpy as np @@ -44,7 +44,7 @@ class QAICInferenceSession: def __init__( self, qpc_path: str, - device_ids: List[int] = [0], + device_ids: Optional[List[int]] = None, activate: bool = True, enable_debug_logs: bool = False, ): @@ -58,9 +58,13 @@ def __init__( :enable_debug_logs: bool. If True, It will enable debug logs. Default=False. """ # Load QPC - devices = qaicrt.QIDList(device_ids) - self.context = qaicrt.Context(devices) - self.queue = qaicrt.Queue(self.context, device_ids[0]) # Async API + if device_ids is not None: + devices = qaicrt.QIDList(device_ids) + self.context = qaicrt.Context(devices) + self.queue = qaicrt.Queue(self.context, device_ids[0]) + else: + self.context = qaicrt.Context() + self.queue = qaicrt.Queue(self.context, 0) # Async API if enable_debug_logs: assert ( self.context.setLogLevel(qaicrt.QLogLevel.QL_DEBUG) == qaicrt.QStatus.QS_SUCCESS @@ -80,7 +84,7 @@ def __init__( # Create and load Program prog_properties = qaicrt.QAicProgramProperties() prog_properties.SubmitRetryTimeoutMs = 60_000 - if len(device_ids) > 1: + if device_ids and len(device_ids) > 1: prog_properties.devMapping = ":".join(map(str, device_ids)) self.program = qaicrt.Program(self.context, None, qpc, prog_properties) assert self.program.load() == qaicrt.QStatus.QS_SUCCESS, "Failed to load program" @@ -170,14 +174,14 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: for binding, (elemsize, shape), (_, passed_shape) in zip( self.bindings, allowed_shape, self.buf_dims ): - if passed_shape[0] == 0: + if passed_shape == [0]: if not binding.is_partial_buf_allowed: warn(f"Partial buffer not allowed for: {binding.name}") continue error_message += f"{binding.name}:\t{elemsize}\t{shape}\n" error_message += "\n\nPassed shapes:\n" for binding, (elemsize, shape) in zip(self.bindings, self.buf_dims): - if shape[0] == 0: + if shape == [0]: continue error_message += f"{binding.name}:\t{elemsize}\t{shape}\n" raise ValueError(error_message) @@ -188,7 +192,7 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: outputs = {} for output_name in self.output_names: buffer_index = self.binding_index_map[output_name] - if self.buf_dims[buffer_index][1][0] == 0: + if self.qbuffers[buffer_index].size == 0: continue outputs[output_name] = 
np.frombuffer( bytes(output_qbuffers[buffer_index]), diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index d3cd8724..27efb4bc 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -99,7 +99,7 @@ def latency_stats_bertstyle( qpc_path: str, seq_len: int, prompt: str, - device_id: List[int] = [0], + device_id: Optional[List[int]] = None, ): """ Function to execute Bertstyle ONNX model on Cloud AI 100. @@ -196,7 +196,7 @@ def cloud_ai_100_exec_kv_helper( prompt: List[str], ctx_len: int, generation_len: Optional[int] = None, - device_id: List[int] = [0], + device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -342,7 +342,7 @@ def cloud_ai_100_exec_kv( qpc_path: str, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, - device_id: List[int] = [0], + device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, enable_debug_logs: bool = False, stream: bool = True, @@ -362,7 +362,7 @@ def cloud_ai_100_exec_kv( :prompt (str): Sample prompt for the model text generation. ``Defaults to None``. :prompts_txt_file_path (str): Path of the prompt text file. ``Defaults to None``. :generation_len (int): Maximum context length for the model during compilation. ``Defaults to None``. - :device_id (List[int]): Device IDs to be used for compilation. If ``len(device_id) > 1``, it enables multiple card setup. ``Defaults to [0]``. + :device_id (List[int]): Device IDs to be used for execution. If ``len(device_id) > 1``, it enables multiple card setup. If ``None``, auto-device-picker will be used. ``Defaults to None``. :enable_debug_logs (bool): If True, it enables debugging logs. ``Defaults to False``. :stream (bool): If True, enable streamer, which returns tokens one by one as the model generates them. ``Defaults to True``. :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``. 
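With the hunks above, `device_ids` / `device_id` become optional along the whole execution path. Below is a minimal sketch of the resulting auto-device-picker behaviour, assuming an already compiled QPC; the QPC path and the input binding name/shape are placeholders, not taken from this patch:

```python
import numpy as np

from QEfficient.generation.cloud_infer import QAICInferenceSession

# device_ids now defaults to None: the session builds a qaicrt.Context without
# an explicit QIDList, letting the runtime pick an available Cloud AI 100 device.
session = QAICInferenceSession("qpcs/gpt2_qpc")  # placeholder QPC path

# An explicit list still works; more than one id sets devMapping for multi-card runs.
# session = QAICInferenceSession("qpcs/gpt2_qpc", device_ids=[0, 1])

inputs = {"input_ids": np.zeros((1, 32), dtype=np.int64)}  # placeholder binding name/shape
outputs = session.run(inputs)  # dict of numpy arrays keyed by output name
```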
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index ca73abbd..2f5065fe 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -157,7 +157,7 @@ def get_qpc_dir_path( ) -> str: qpc_base_dir_name = ( f"qpc_{num_cores}cores_{batch_size}BS_{prompt_len}PL_{ctx_len}CL_{mos}MOS_" - + f"{len(device_group)}" + + f"{len(device_group) if device_group is not None else 1}" + "devices" + ("_mxfp6_mxint8" if (mxfp6 and mxint8) else "_mxfp6" if mxfp6 else "_fp16_mxint8" if mxint8 else "_fp16") ) diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 8acd36f4..74a33e74 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -173,7 +173,7 @@ def run_kv_model_on_ort(self, model_path): print("Completion:", repr(predicted_string)) return generated_ids - def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group): + def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None): """ Function responsible for running ``ONNX`` model on Cloud AI 100 and return the output tokens diff --git a/pyproject.toml b/pyproject.toml index a68318ec..ab2d5e94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,4 +56,6 @@ line-length = 120 lint.extend-select = ["I"] [tool.pytest.ini_options] +addopts = "-W ignore -s -v" +junit_logging = "all" doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS" diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 51687c25..48ca48ca 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -20,6 +20,7 @@ pipeline . preflight_qeff/bin/activate pip install --upgrade pip setuptools pip install .[test] + pip install junitparser pytest-xdist rm -rf QEfficient ''' } @@ -35,7 +36,9 @@ pipeline sh ''' . preflight_qeff/bin/activate export TOKENIZERS_PARALLELISM=false - pytest -W ignore -s -v tests -o junit_logging=all --junitxml=tests/tests_log.xml + pytest tests --ignore tests/cloud -n 4 --junitxml=tests/tests_log1.xml + pytest tests/cloud --junitxml=tests/tests_log2.xml + junitparser merge tests/tests_log1.xml tests/tests_log2.xml tests/tests_log.xml deactivate exit ''' diff --git a/tests/base/test_onnx_transforms.py b/tests/base/test_onnx_transforms.py index fa5c6253..dbbbbda1 100644 --- a/tests/base/test_onnx_transforms.py +++ b/tests/base/test_onnx_transforms.py @@ -5,22 +5,10 @@ # # ---------------------------------------------------------------------------- -import os -import shutil - import numpy as np import onnx -import pytest - -from QEfficient.base.onnx_transforms import FP16Clip - -@pytest.fixture -def external_path(): - external_dir = "tmp_external_data" - os.makedirs(external_dir, exist_ok=True) - yield external_dir - shutil.rmtree(external_dir) +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform def test_fp16clip_transform(): @@ -44,39 +32,87 @@ def test_fp16clip_transform(): } """) onnx.checker.check_model(test_onnx, True, True, True) - transformed_onnx, transformed = FP16Clip.apply(test_onnx) + transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx) assert transformed assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == 65504.0 assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[1]) == 2147483647 assert onnx.numpy_helper.to_array(transformed_onnx.graph.node[1].attribute[0].t) == -65504.0 -def test_fp16clip_transform_external(external_path): - external_weight_file = "fp32_min.weight" - test_onnx = onnx.parser.parse_model( - """ +def test_fp16clip_transform_external(tmp_path): + 
external_tensors_file = "fp32_min.raw" + test_onnx = onnx.parser.parse_model(f""" < ir_version: 8, opset_import: ["" : 17] > test_fp16clip (float [n, 32] x) => (float [n, 32] y) < - float min_val = [ "location": "" ], - float zero = {0.0} + float min_val = [ "location": "{external_tensors_file}" ], + float zero = {{0.0}} > - { + {{ mask = Greater(x, zero) y = Where(mask, x, min_val) - } - """.replace("", str(external_weight_file)) - ) + }} + """) # Write onnx and external_data - onnx_path = os.path.join(external_path, "test_fp16_clip_external.onnx") + onnx_path = tmp_path / "test_fp16_clip_external.onnx" onnx.save(test_onnx, onnx_path) - np.array(-1e10, dtype="float32").tofile(os.path.join(external_path, external_weight_file)) - + np.array(-1e10, dtype="float32").tofile(tmp_path / external_tensors_file) onnx.checker.check_model(onnx_path, True, True, True) - transformed_onnx, transformed = FP16Clip.apply(test_onnx, external_path) + + transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx, onnx_base_dir=str(tmp_path)) assert transformed assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == -65504.0 + + +def test_split_tensors_transform(tmp_path): + external_tensors_file = "tensors.raw" + test_onnx = onnx.parser.parse_model(f""" + < + ir_version: 8, + opset_import: ["": 17] + > + test_split () => () + < + float[1, 32] tensor0 = [ "location": "{external_tensors_file}", "offset": "0", "length": "{32*4}" ], + float[1, 32] tensor1 = [ "location": "{external_tensors_file}", "offset": "{32*4}", "length": "{32*4}" ], + float[1, 16] tensor2 = [ "location": "{external_tensors_file}", "offset": "{64*4}", "length": "{16*4}" ] + > + {{ + }} + """) + + # Write onnx and external_data + onnx_path = tmp_path / "test_split_pre.onnx" + onnx.save(test_onnx, onnx_path) + tensors = np.random.rand(32 + 32 + 16).astype("float32") + tensors.tofile(tmp_path / external_tensors_file) + onnx.checker.check_model(onnx_path, True, True, True) + + trans_onnx, transformed = SplitTensorsTransform.apply( + test_onnx, + model_name="test_split", + onnx_base_dir=str(tmp_path), + file_chunk_size=32 * 4, + size_threshold=16 * 4, + ) + + tensor0_ext_data = onnx.external_data_helper.ExternalDataInfo(trans_onnx.graph.initializer[0]) + assert tensor0_ext_data.location == "test_split_0.onnx.data" + + tensor1_ext_data = onnx.external_data_helper.ExternalDataInfo(trans_onnx.graph.initializer[1]) + assert tensor1_ext_data.location == "test_split_1.onnx.data" + + tensor2 = trans_onnx.graph.initializer[2] + assert tensor2.data_location == onnx.TensorProto.DataLocation.Value("DEFAULT") + assert np.all(onnx.numpy_helper.to_array(tensor2) == tensors[-16:]) + + # Save and test if all files are saved + onnx_path = tmp_path / "test_split.onnx" + onnx.save(trans_onnx, onnx_path) + assert onnx_path.is_file() + assert onnx_path.with_name(onnx_path.name.replace(".onnx", "_0.onnx.data")).is_file() + assert onnx_path.with_name(onnx_path.name.replace(".onnx", "_1.onnx.data")).is_file() diff --git a/tests/cloud/conftest.py b/tests/cloud/conftest.py index 6a3a0ded..d7a1a706 100644 --- a/tests/cloud/conftest.py +++ b/tests/cloud/conftest.py @@ -287,6 +287,8 @@ def pytest_sessionstart(session): def pytest_sessionfinish(session, exitstatus): - cache_clean_up() - qeff_models_clean_up() - logger.info("...PYTEST Session Ended.") + inside_worker = getattr(session.config, "workerinput", None) + if inside_worker is None: + cache_clean_up() + qeff_models_clean_up() + logger.info("...PYTEST Session Ended.") diff --git 
a/tests/cloud/high_level_testing.json b/tests/cloud/high_level_testing.json index 5735ef47..1a4ef033 100644 --- a/tests/cloud/high_level_testing.json +++ b/tests/cloud/high_level_testing.json @@ -1,6 +1,6 @@ { "license": "SEE LICENSE IN LICENSE FILE", - "model_name" : ["gpt2","TinyLlama/TinyLlama-1.1B-Chat-v1.0","Salesforce/codegen-350M-mono","wtang06/mpt-125m-c4"], + "model_name" : ["gpt2"], "num_cores" : [16], "prompt" : ["My name is"], "prompts_txt_file_path" : ["examples/prompts.txt"], @@ -13,5 +13,5 @@ "ctx_len" : [128], "mxfp6" : [1], "mxint8" : [1], - "device_group" : [[0]] + "device_group" : [null] } \ No newline at end of file diff --git a/tests/cloud/test_compile.py b/tests/cloud/test_compile.py index 68f76cf4..ceea0120 100644 --- a/tests/cloud/test_compile.py +++ b/tests/cloud/test_compile.py @@ -7,10 +7,13 @@ import os +import pytest + import QEfficient import QEfficient.cloud.compile +@pytest.mark.cli def test_compile(setup, mocker): """ test_compile is a HL compile api testing function, diff --git a/tests/cloud/test_execute.py b/tests/cloud/test_execute.py index 4b6973f1..789c1566 100644 --- a/tests/cloud/test_execute.py +++ b/tests/cloud/test_execute.py @@ -5,12 +5,14 @@ # # ----------------------------------------------------------------------------- +import pytest import QEfficient import QEfficient.cloud.execute from QEfficient.cloud.execute import main as execute +@pytest.mark.cli def test_execute(setup, mocker): """ test_execute is a HL execute api testing function, @@ -27,7 +29,6 @@ def test_execute(setup, mocker): execute( model_name=ms.model_name, qpc_path=ms.qpc_dir_path(), - device_group=ms.device_group, prompt=ms.prompt, prompts_txt_file_path=ms.prompts_txt_file_path, hf_token=ms.hf_token, diff --git a/tests/cloud/test_export.py b/tests/cloud/test_export.py index 3334bcee..4560d318 100644 --- a/tests/cloud/test_export.py +++ b/tests/cloud/test_export.py @@ -7,12 +7,15 @@ import os +import pytest + import QEfficient import QEfficient.cloud.export from QEfficient.cloud.export import main as export from QEfficient.utils.constants import Constants +@pytest.mark.cli def test_export(setup, mocker): """ test_export is a HL export api testing function, diff --git a/tests/cloud/test_infer.py b/tests/cloud/test_infer.py index 7ddd20e5..e6a61d84 100644 --- a/tests/cloud/test_infer.py +++ b/tests/cloud/test_infer.py @@ -14,6 +14,7 @@ from QEfficient.cloud.infer import main as infer +@pytest.mark.cli @pytest.mark.usefixtures("clean_up_after_test") def test_infer(setup, mocker): """ @@ -46,7 +47,6 @@ def test_infer(setup, mocker): ctx_len=ms.ctx_len, mxfp6=ms.mxfp6, mxint8=ms.mxint8, - device_group=ms.device_group, ) # tokenizer check load_hf_tokenizer_spy.assert_called_once() diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 313e666f..ffd61830 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -5,10 +5,17 @@ # # ----------------------------------------------------------------------------- +import os + import pytest +from QEfficient.compile.compile_helper import compile_kv_model_on_cloud_ai_100 +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants from QEfficient.utils.device_utils import get_available_device_id -from tests.utils import get_cloud_ai_100_tokens, set_up +from QEfficient.utils.run_utils 
import ApiRunner +from tests.utils import load_pytorch_model test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -26,41 +33,67 @@ ] +@pytest.mark.causal_lm @pytest.mark.parametrize("model_name", test_models) -class TestQEfficientModels: - def setup_class(cls): - """ - Set up function to set up the test environment for TestQEfficientModels class - :param cls - """ - cls.setup_infos = {model_name: set_up({"model_name": model_name}) for model_name in test_models} - - def test_qefficient_model_torch(self, model_name): - """ - Test function to validate the model before and after KV changes on Pytorch - :param model_name: Name of model. - """ - assert ( - self.setup_infos[model_name]["pytorch_hf_tokens"] == self.setup_infos[model_name]["pytorch_kv_tokens"] - ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - - def test_qefficient_model_onnx(self, model_name): - """ - Test function to validate the model before and after KV changes on ONNXRT - :param model_name: Name of model. - """ - assert ( - self.setup_infos[model_name]["pytorch_kv_tokens"] == self.setup_infos[model_name]["ort_tokens"] - ).all(), "Tokens don't match for ONNXRT output and PyTorch output." - - @pytest.mark.skipif(not get_available_device_id, reason="No available devices to run model on Cloud AI 100") - def test_qefficient_model_cloud_ai_100(self, model_name): - """ - Test function to validate the model before and after KV changes on Cloud AI 100 - :param model_name: Name of model. - """ - - cloud_ai_100_tokens = get_cloud_ai_100_tokens(self.setup_infos[model_name]) - assert ( - self.setup_infos[model_name]["ort_tokens"] == cloud_ai_100_tokens - ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the model before and after KV changes on Pytorch + :param model_name: Name of model. + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 1 + + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_pytorch_model(model_config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + qeff_model = QEFFAutoModelForCausalLM(model_hf, f"{model_name}") + + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + assert ( + pytorch_hf_tokens == pytorch_kv_tokens + ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + + onnx_model_path = qeff_model.export() + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) + + assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." 
+ + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + base_path = os.path.dirname(onnx_model_path) + tests_qpc_dir = os.path.join(base_path, "tests_qpc") + os.makedirs(tests_qpc_dir, exist_ok=True) + + _, test_qpcs_path = compile_kv_model_on_cloud_ai_100( + onnx_path=onnx_model_path, + specializations_json="scripts/specializations.json", + num_cores=14, + base_path=tests_qpc_dir, + mxfp6=False, + custom_io_path=os.path.join(base_path, "custom_io_fp16.yaml"), + aic_enable_depth_first=False, + ) + + cloud_ai_100_tokens = api_runner.run_kv_model_on_cloud_ai_100(test_qpcs_path) + + assert (ort_tokens == cloud_ai_100_tokens).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." diff --git a/tests/utils.py b/tests/utils.py index 31ee4ec1..5f743396 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,18 +6,12 @@ # ----------------------------------------------------------------------------- import functools -import os import unittest from transformers import AutoModelForCausalLM -from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.compile.compile_helper import compile_kv_model_on_cloud_ai_100 -from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter -from QEfficient.utils import hf_download, load_hf_tokenizer -from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants -from QEfficient.utils.device_utils import get_available_device_id, is_multi_qranium_setup_available, is_qpc_size_gt_32gb -from QEfficient.utils.run_utils import ApiRunner +from QEfficient.utils import hf_download +from QEfficient.utils.device_utils import is_multi_qranium_setup_available def skip_if_mq_not_enabled(test_method): @@ -54,105 +48,3 @@ def load_pytorch_model(model_config): params = sum(p.numel() for p in model_hf.parameters()) model_hf.eval() return model_hf, params - - -def export_onnx(model_kv, tokenizer, model_name): - """ - Function to export onnx model - --------- - - :model_kv: transformed pytorch model to be exported to ONNX. - :tokenizer: model tokenizer. - :model_name: str. 
- - :return base_path, onnx_model_path : str - """ - onnx_dir_path = os.path.join(QEFF_MODELS_DIR, model_name) - base_path, onnx_model_path = qualcomm_efficient_converter( - model_name=model_name, - model_kv=QEFFAutoModelForCausalLM(model=model_kv, pretrained_model_name_or_path=model_name), # type: ignore - tokenizer=tokenizer, - onnx_dir_path=onnx_dir_path, - kv=True, - ) - return base_path, onnx_model_path - - -def set_up(model_config, device_group=[0]): - """ - Set up function to set up the test environment for TestQEfficientModel class - """ - if model_config["model_name"] == "microsoft/Phi-3-mini-4k-instruct": - n_layer = 2 # test only 2 layer models - else: - n_layer = 1 - - model_config["n_layer"] = n_layer - - mxfp6 = False - model_hf, params = load_pytorch_model(model_config) - qpc_gt_32gb = is_qpc_size_gt_32gb(params, mxfp6) - - tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_config["model_name"]) - config = model_hf.config - batch_size = len(Constants.INPUT_STR) - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - Constants.INPUT_STR, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - ) - try: - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - except Exception as e: - print(f"Pytorch HuggingFace Pytorch Model run failed due to : {e}") - - qeff_model = QEFFAutoModelForCausalLM(model_hf, f"{model_config['model_name']}") - - pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - - onnx_model_path = qeff_model.export() - ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) - - setup_info = {} - setup_info["model_config"] = model_config - setup_info["device_group"] = device_group - setup_info["api_runner"] = api_runner - setup_info["qpc_gt_32gb"] = qpc_gt_32gb - setup_info["pytorch_hf_tokens"] = pytorch_hf_tokens - setup_info["pytorch_kv_tokens"] = pytorch_kv_tokens - setup_info["onnx_model_path"] = onnx_model_path - setup_info["ort_tokens"] = ort_tokens - return setup_info - - -def get_cloud_ai_100_tokens(setup_info): - """ - Test function to validate the llama model before and after KV changes on Cloud AI 100 - :param None - """ - device_id = get_available_device_id() - base_path = os.path.dirname(setup_info["onnx_model_path"]) - tests_qpc_dir = os.path.join(base_path, "tests_qpc") - os.makedirs(tests_qpc_dir, exist_ok=True) - if device_id: - _, test_qpcs_path = compile_kv_model_on_cloud_ai_100( - onnx_path=setup_info["onnx_model_path"], - specializations_json="scripts/specializations.json", - num_cores=14, - base_path=tests_qpc_dir, - mxfp6=False, - custom_io_path=os.path.join(base_path, "custom_io_fp16.yaml"), - aic_enable_depth_first=False, - device_group=setup_info["device_group"], - ) - try: - cloud_ai_100_tokens = setup_info["api_runner"].run_kv_model_on_cloud_ai_100( - test_qpcs_path, setup_info["device_group"] - ) - except Exception as e: - print(f"ONNX Model run on Cloud AI 100 failed due to : {e}") - - return cloud_ai_100_tokens
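For reference, the renamed `FP16ClipTransform` and the new `SplitTensorsTransform` introduced above are applied as plain classmethods on a loaded `ModelProto`. A minimal sketch of chaining them, assuming an exported ONNX model with external data; the paths and `model_name` are placeholders, and the keyword defaults shown match the ones added in this patch:

```python
import onnx

from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform

# Load the graph only; the transforms pull external tensors from onnx_base_dir.
model = onnx.load("models/my_model/model.onnx", load_external_data=False)

# Clip float32 constants that fall outside the FP16 representable range.
model, clipped = FP16ClipTransform.apply(model, onnx_base_dir="models/my_model")

# Re-shard external tensors into files named {model_name}_{n}.onnx.data,
# starting a new file once the running size exceeds file_chunk_size.
model, split = SplitTensorsTransform.apply(
    model,
    model_name="my_model",
    onnx_base_dir="models/my_model",
    file_chunk_size=10 * 2**30,  # 10 GiB (default)
    size_threshold=1024,         # tensors of 1 KiB or less stay inline (default)
)

# Saving writes out the re-split external data files next to the ONNX file,
# mirroring the checks in tests/base/test_onnx_transforms.py.
onnx.save(model, "models/my_model/model_transformed.onnx")
```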