From a532fea3f0736bdeb8e21b32673d64745f3d80c3 Mon Sep 17 00:00:00 2001 From: Wenbing Li Date: Wed, 14 Feb 2024 11:36:41 -0800 Subject: [PATCH 1/3] putting onnx package to be optional --- onnxruntime_extensions/__init__.py | 70 ++++++++++----- onnxruntime_extensions/_ocos.py | 123 +++++--------------------- onnxruntime_extensions/_ortapi2.py | 135 ++++++++++++++++++++++------- requirements-dev.txt | 5 +- requirements.txt | 1 - setup.py | 77 +++++++--------- 6 files changed, 208 insertions(+), 203 deletions(-) delete mode 100644 requirements.txt diff --git a/onnxruntime_extensions/__init__.py b/onnxruntime_extensions/__init__.py index afd41b881..979d6f883 100644 --- a/onnxruntime_extensions/__init__.py +++ b/onnxruntime_extensions/__init__.py @@ -10,37 +10,61 @@ __author__ = "Microsoft" -__all__ = [ - 'gen_processing_models', - 'ort_inference', - 'get_library_path', - 'Opdef', 'onnx_op', 'PyCustomOpDef', 'PyOp', - 'enable_py_op', - 'expand_onnx_inputs', - 'hook_model_op', - 'default_opset_domain', - 'OrtPyFunction', 'PyOrtFunction', - 'optimize_model', - 'make_onnx_model', - 'ONNXRuntimeError', - 'hash_64', - '__version__', -] from ._version import __version__ from ._ocos import get_library_path from ._ocos import Opdef, PyCustomOpDef from ._ocos import hash_64 from ._ocos import enable_py_op -from ._ocos import expand_onnx_inputs -from ._ocos import hook_model_op from ._ocos import default_opset_domain -from ._cuops import * # noqa -from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility -from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model -from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException -from .cvt import gen_processing_models + + +_lib_only = False + +try: + import onnx # noqa + import onnxruntime # noqa +except ImportError: + _lib_only = True + pass + + +_offline_api = [ + "gen_processing_models", + "ort_inference", + "OrtPyFunction", + "PyOrtFunction", + "optimize_model", + "make_onnx_model", + "ONNXRuntimeError", +] + +__all__ = [ + "get_library_path", + "Opdef", + "onnx_op", + "PyCustomOpDef", + "PyOp", + "enable_py_op", + "expand_onnx_inputs", + "hook_model_op", + "default_opset_domain", + "hash_64", + "__version__", +] # rename the implementation with a more formal name onnx_op = Opdef.declare PyOp = PyCustomOpDef + + +if not _lib_only: + __all__ += _offline_api + + from ._cuops import * # noqa + from ._ortapi2 import hook_model_op + from ._ortapi2 import expand_onnx_inputs + from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model + from ._ortapi2 import OrtPyFunction as PyOrtFunction # backward compatibility + from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException # noqa + from .cvt import gen_processing_models diff --git a/onnxruntime_extensions/_ocos.py b/onnxruntime_extensions/_ocos.py index 50f0f0046..e1f536814 100644 --- a/onnxruntime_extensions/_ocos.py +++ b/onnxruntime_extensions/_ocos.py @@ -7,34 +7,36 @@ """ import os import sys -import copy import glob -import onnx -from onnx import helper def _search_cuda_dir(): - paths = os.getenv('PATH', '').split(os.pathsep) + paths = os.getenv("PATH", "").split(os.pathsep) for path in paths: - for filename in glob.glob(os.path.join(path, 'cudart64*.dll')): + for filename in glob.glob(os.path.join(path, "cudart64*.dll")): return os.path.dirname(filename) return None -if sys.platform == 'win32': +if sys.platform == "win32": from . import _version # noqa: E402 - if hasattr(_version, 'cuda'): + + if hasattr(_version, "cuda"): cuda_path = _search_cuda_dir() if cuda_path is None: - raise RuntimeError( - "Cannot locate CUDA directory in the environment variable for GPU package") + raise RuntimeError("Cannot locate CUDA directory in the environment variable for GPU package") os.add_dll_directory(cuda_path) from ._extensions_pydll import ( # noqa - PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain) + PyCustomOpDef, + enable_py_op, + add_custom_op, + hash_64, + default_opset_domain, +) def get_library_path(): @@ -42,12 +44,11 @@ def get_library_path(): The custom operator library binary path :return: A string of this library path. """ - mod = sys.modules['onnxruntime_extensions._extensions_pydll'] + mod = sys.modules["onnxruntime_extensions._extensions_pydll"] return mod.__file__ class Opdef: - _odlist = {} def __init__(self, op_type, func): @@ -57,14 +58,14 @@ def __init__(self, op_type, func): @staticmethod def declare(*args, **kwargs): - if len(args) > 0 and hasattr(args[0], '__call__'): + if len(args) > 0 and hasattr(args[0], "__call__"): raise RuntimeError("Unexpected arguments {}.".format(args)) # return Opdef._create(args[0]) return lambda f: Opdef.create(f, *args, **kwargs) @staticmethod def create(func, *args, **kwargs): - name = kwargs.get('op_type', None) + name = kwargs.get("op_type", None) op_type = name or func.__name__ opdef = Opdef(op_type, func) od_id = id(opdef) @@ -76,15 +77,15 @@ def create(func, *args, **kwargs): opdef._nativedef.op_type = op_type opdef._nativedef.obj_id = od_id - inputs = kwargs.get('inputs', None) + inputs = kwargs.get("inputs", None) if inputs is None: inputs = [PyCustomOpDef.dt_float] opdef._nativedef.input_types = inputs - outputs = kwargs.get('outputs', None) + outputs = kwargs.get("outputs", None) if outputs is None: outputs = [PyCustomOpDef.dt_float] opdef._nativedef.output_types = outputs - attrs = kwargs.get('attrs', None) + attrs = kwargs.get("attrs", None) if attrs is None: attrs = {} elif isinstance(attrs, (list, tuple)): @@ -106,16 +107,15 @@ def cast_attributes(self, attributes): elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string: res[k] = v else: - raise RuntimeError("Unsupported attribute type {}.".format( - self._nativedef.attrs[k])) + raise RuntimeError("Unsupported attribute type {}.".format(self._nativedef.attrs[k])) return res def _on_pyop_invocation(k_id, feed, attributes): if k_id not in Opdef._odlist: raise RuntimeError( - "Unable to find function id={}. " - "Did you decorate the operator with @onnx_op?.".format(k_id)) + "Unable to find function id={}. " "Did you decorate the operator with @onnx_op?.".format(k_id) + ) op_ = Opdef._odlist[k_id] rv = op_.body(*feed, **op_.cast_attributes(attributes)) if isinstance(rv, tuple): @@ -127,86 +127,7 @@ def _on_pyop_invocation(k_id, feed, attributes): res = tuple(res) else: res = (rv.shape, rv.flatten().tolist()) - return (k_id, ) + res - - -def _ensure_opset_domain(model): - op_domain_name = default_opset_domain() - domain_missing = True - for oi_ in model.opset_import: - if oi_.domain == op_domain_name: - domain_missing = False - - if domain_missing: - model.opset_import.extend( - [helper.make_operatorsetid(op_domain_name, 1)]) - - return model - - -def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs): - """ - Replace the existing inputs of a model with the new inputs, plus some extra nodes - :param model: The ONNX model loaded as ModelProto - :param target_input: The input name to be replaced - :param extra_nodes: The extra nodes to be added - :param new_inputs: The new input (type: ValueInfoProto) sequence - :return: The ONNX model after modification - """ - graph = model.graph - new_inputs = [n for n in graph.input if n.name != - target_input] + new_inputs - new_nodes = list(model.graph.node) + extra_nodes - new_graph = helper.make_graph( - new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer)) - - new_model = copy.deepcopy(model) - new_model.graph.CopyFrom(new_graph) - - return _ensure_opset_domain(new_model) - - -def hook_model_op(model, node_name, hook_func, input_types): - """ - Add a hook function node in the ONNX Model, which could be used for the model diagnosis. - :param model: The ONNX model loaded as ModelProto - :param node_name: The node name where the hook will be installed - :param hook_func: The hook function, callback on the model inference - :param input_types: The input types as a list - :return: The ONNX model with the hook installed - """ - - # onnx.shape_inference is very unstable, useless. - # hkd_model = shape_inference.infer_shapes(model) - hkd_model = model - - n_idx = 0 - hnode, nnode = (None, None) - nodes = list(hkd_model.graph.node) - brkpt_name = node_name + '_hkd' - optype_name = "op_{}_{}".format(hook_func.__name__, node_name) - for n_ in nodes: - if n_.name == node_name: - input_names = list(n_.input) - brk_output_name = [i_ + '_hkd' for i_ in input_names] - hnode = onnx.helper.make_node( - optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()) - nnode = n_ - del nnode.input[:] - nnode.input.extend(brk_output_name) - break - n_idx += 1 - - if hnode is None: - raise ValueError("{} is not an operator node name".format(node_name)) - - repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx+1:] - del hkd_model.graph.node[:] - hkd_model.graph.node.extend(repacked) - - Opdef.create(hook_func, op_type=optype_name, - inputs=input_types, outputs=input_types) - return _ensure_opset_domain(hkd_model) + return (k_id,) + res PyCustomOpDef.install_hooker(_on_pyop_invocation) diff --git a/onnxruntime_extensions/_ortapi2.py b/onnxruntime_extensions/_ortapi2.py index 425aa26a8..b01681347 100644 --- a/onnxruntime_extensions/_ortapi2.py +++ b/onnxruntime_extensions/_ortapi2.py @@ -7,8 +7,9 @@ _ortapi2.py: ONNXRuntime-Extensions Python API """ +import copy import numpy as np -from ._ocos import default_opset_domain, get_library_path # noqa +from ._ocos import default_opset_domain, get_library_path, Opdef from ._cuops import onnx, onnx_proto, SingleOpGraph _ort_check_passed = False @@ -25,6 +26,82 @@ raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0") +def _ensure_opset_domain(model): + op_domain_name = default_opset_domain() + domain_missing = True + for oi_ in model.opset_import: + if oi_.domain == op_domain_name: + domain_missing = False + + if domain_missing: + model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)]) + + return model + + +def hook_model_op(model, node_name, hook_func, input_types): + """ + Add a hook function node in the ONNX Model, which could be used for the model diagnosis. + :param model: The ONNX model loaded as ModelProto + :param node_name: The node name where the hook will be installed + :param hook_func: The hook function, callback on the model inference + :param input_types: The input types as a list + :return: The ONNX model with the hook installed + """ + + # onnx.shape_inference is very unstable, useless. + # hkd_model = shape_inference.infer_shapes(model) + hkd_model = model + + n_idx = 0 + hnode, nnode = (None, None) + nodes = list(hkd_model.graph.node) + brkpt_name = node_name + "_hkd" + optype_name = "op_{}_{}".format(hook_func.__name__, node_name) + for n_ in nodes: + if n_.name == node_name: + input_names = list(n_.input) + brk_output_name = [i_ + "_hkd" for i_ in input_names] + hnode = onnx.helper.make_node( + optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain() + ) + nnode = n_ + del nnode.input[:] + nnode.input.extend(brk_output_name) + break + n_idx += 1 + + if hnode is None: + raise ValueError("{} is not an operator node name".format(node_name)) + + repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :] + del hkd_model.graph.node[:] + hkd_model.graph.node.extend(repacked) + + Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types) + return _ensure_opset_domain(hkd_model) + + +def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs): + """ + Replace the existing inputs of a model with the new inputs, plus some extra nodes + :param model: The ONNX model loaded as ModelProto + :param target_input: The input name to be replaced + :param extra_nodes: The extra nodes to be added + :param new_inputs: The new input (type: ValueInfoProto) sequence + :return: The ONNX model after modification + """ + graph = model.graph + new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs + new_nodes = list(model.graph.node) + extra_nodes + new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer)) + + new_model = copy.deepcopy(model) + new_model.graph.CopyFrom(new_graph) + + return _ensure_opset_domain(new_model) + + def get_opset_version_from_ort(): _ORT_OPSET_SUPPORT_TABLE = { "1.5": 11, @@ -37,10 +114,10 @@ def get_opset_version_from_ort(): "1.12": 17, "1.13": 17, "1.14": 18, - "1.15": 18 + "1.15": 18, } - ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2]) + ort_ver_string = ".".join(_ort.__version__.split(".")[0:2]) max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get) if ort_ver_string > max_ver: ort_ver_string = max_ver @@ -50,19 +127,18 @@ def get_opset_version_from_ort(): def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1): if opset_version == 0: opset_version = get_opset_version_from_ort() - fn_mm = onnx.helper.make_model_gen_version if hasattr(onnx.helper, 'make_model_gen_version' - ) else onnx.helper.make_model - model = fn_mm(graph, opset_imports=[ - onnx.helper.make_operatorsetid('ai.onnx', opset_version)]) - model.opset_import.extend( - [onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)]) + fn_mm = ( + onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model + ) + model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)]) + model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)]) return model class OrtPyFunction: """ OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession, - equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a + equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a standard Python function. The order of the function arguments correlates directly with the sequence of the input/output in the ONNX graph. """ @@ -78,10 +154,10 @@ def __init__(self, path_or_model=None, cpu_only=None): self._onnx_model = None self.ort_session = None self.default_inputs = {} - self.execution_providers = ['CPUExecutionProvider'] + self.execution_providers = ["CPUExecutionProvider"] if not cpu_only: - if _ort.get_device() == 'GPU': - self.execution_providers = ['CUDAExecutionProvider'] + if _ort.get_device() == "GPU": + self.execution_providers = ["CUDAExecutionProvider"] self.extra_session_options = {} mpath = None if isinstance(path_or_model, str): @@ -99,8 +175,8 @@ def create_from_customop(self, op_type, *args, **kwargs): def add_default_input(self, **kwargs): inputs = { - ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else - np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items() + ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8) + for ky_, val_ in kwargs.items() } self.default_inputs.update(inputs) @@ -124,30 +200,29 @@ def _bind(self, oxml, model_path=None): self._oxml = oxml if model_path is not None: self.ort_session = _ort.InferenceSession( - model_path, self.get_ort_session_options(), - self.execution_providers) + model_path, self.get_ort_session_options(), self.execution_providers + ) return self def _ensure_ort_session(self): if self.ort_session is None: sess = _ort.InferenceSession( - self.onnx_model.SerializeToString(), self.get_ort_session_options(), - self.execution_providers) + self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers + ) self.ort_session = sess return self.ort_session @staticmethod def _get_kwarg_device(kwargs): - cpuonly = kwargs.get('cpu_only', None) + cpuonly = kwargs.get("cpu_only", None) if cpuonly is not None: - del kwargs['cpu_only'] + del kwargs["cpu_only"] return cpuonly @classmethod def from_customop(cls, op_type, *args, **kwargs): - return (cls(cpu_only=cls._get_kwarg_device(kwargs)) - .create_from_customop(op_type, *args, **kwargs)) + return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs) @classmethod def from_model(cls, path_or_model, *args, **kwargs): @@ -165,9 +240,9 @@ def _argument_map(self, *args, **kwargs): x = args[idx] ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x # numpy by default is int32 in some platforms, sometimes it is int64. - feed[i_.name] = \ - ts_x.astype( - np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x + feed[i_.name] = ( + ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x + ) idx += 1 feed.update(kwargs) @@ -175,8 +250,7 @@ def _argument_map(self, *args, **kwargs): def __call__(self, *args, **kwargs): self._ensure_ort_session() - outputs = self.ort_session.run( - None, self._argument_map(*args, **kwargs)) + outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs)) return outputs[0] if len(outputs) == 1 else tuple(outputs) @@ -191,8 +265,9 @@ def optimize_model(model_or_file, output_file): sess_options = OrtPyFunction().get_ort_session_options() sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC sess_options.optimized_model_filepath = output_file - _ort.InferenceSession(model_or_file if isinstance(model_or_file, str) - else model_or_file.SerializeToString(), sess_options) + _ort.InferenceSession( + model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options + ) ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail diff --git a/requirements-dev.txt b/requirements-dev.txt index c1667bc63..47114f13f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,7 @@ -# include requirements.txt so pip has context to avoid installing incompatible dependencies --r requirements.txt pytest -# multiple versions of onnxruntime are supported, but only one can be installed at a time +onnx >= 1.9.0 protobuf < 4.0.0 +# multiple versions of onnxruntime are supported, but only one can be installed at a time onnxruntime >=1.12.0 transformers >=4.9.2 tensorflow_text >=2.5.0;python_version < '3.11' diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ccbcb38bd..000000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -onnx>=1.9.0 diff --git a/setup.py b/setup.py index c50816431..3b2100e98 100644 --- a/setup.py +++ b/setup.py @@ -9,14 +9,13 @@ import pathlib import setuptools -from textwrap import dedent from setuptools import setup, find_packages TOP_DIR = os.path.dirname(__file__) or os.getcwd() -PACKAGE_NAME = 'onnxruntime_extensions' +PACKAGE_NAME = "onnxruntime_extensions" # setup.py cannot be debugged in pip command line, so the command classes are refactored into another file -cmds_dir = pathlib.Path(TOP_DIR) / '.pyproject' +cmds_dir = pathlib.Path(TOP_DIR) / ".pyproject" sys.path.append(str(cmds_dir)) # noinspection PyUnresolvedReferences import cmdclass as _cmds # noqa: E402 @@ -24,62 +23,50 @@ _cmds.prepare_env(TOP_DIR) -def read_requirements(): - with open(os.path.join(TOP_DIR, "requirements.txt"), "r", encoding="utf-8") as f: - requirements = [_ for _ in [dedent(_) for _ in f.readlines()] if _ is not None] - return requirements - - # read version from the package file. def read_version(): - version_str = '1.0.0' - with (open(os.path.join(TOP_DIR, 'version.txt'), "r")) as f: + version_str = "0.1.0" + with open(os.path.join(TOP_DIR, "version.txt"), "r") as f: version_str = f.readline().strip() # special handling for Onebranch building - if os.getenv('BUILD_SOURCEBRANCHNAME', "").startswith('rel-'): + if os.getenv("BUILD_SOURCEBRANCHNAME", "").startswith("rel-"): return version_str # is it a dev build or release? - rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir( - os.path.join(TOP_DIR, '.git')) else (True, None) + rel_br, cid = _cmds.read_git_refs(TOP_DIR) if os.path.isdir(os.path.join(TOP_DIR, ".git")) else (True, None) if rel_br: return version_str - build_id = os.getenv('BUILD_BUILDID', None) + build_id = os.getenv("BUILD_BUILDID", None) if build_id is not None: - version_str += '.{}'.format(build_id) + version_str += ".{}".format(build_id) else: - version_str += '+' + cid[:7] + version_str += "+" + cid[:7] return version_str def write_py_version(ext_version): - text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n", - "__version__ = \"{}\"\n".format(ext_version)] - with (open(os.path.join(TOP_DIR, 'onnxruntime_extensions/_version.py'), "w")) as _fver: + text = ["# Generated by setup.py, DON'T MANUALLY UPDATE IT!\n", '__version__ = "{}"\n'.format(ext_version)] + with open(os.path.join(TOP_DIR, "onnxruntime_extensions/_version.py"), "w") as _fver: _fver.writelines(text) -ext_modules = [ - setuptools.extension.Extension( - name=str('onnxruntime_extensions._extensions_pydll'), - sources=[]) -] +ext_modules = [setuptools.extension.Extension(name=str("onnxruntime_extensions._extensions_pydll"), sources=[])] packages = find_packages() -package_dir = {k: os.path.join('.', k.replace(".", "/")) for k in packages} +package_dir = {k: os.path.join(".", k.replace(".", "/")) for k in packages} package_data = { "onnxruntime_extensions": ["*.so", "*.pyd"], } -long_description = '' -with open(os.path.join(TOP_DIR, "README.md"), 'r', encoding="utf-8") as _f: +long_description = "" +with open(os.path.join(TOP_DIR, "README.md"), "r", encoding="utf-8") as _f: long_description += _f.read() - start_pos = long_description.find('# Introduction') + start_pos = long_description.find("# Introduction") start_pos = 0 if start_pos < 0 else start_pos - end_pos = long_description.find('# Contributing') + end_pos = long_description.find("# Contributing") long_description = long_description[start_pos:end_pos] ortx_version = read_version() write_py_version(ortx_version) @@ -92,25 +79,25 @@ def write_py_version(ext_version): package_data=package_data, description="ONNXRuntime Extensions", long_description=long_description, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='onnxruntime@microsoft.com', - url='https://github.com/microsoft/onnxruntime-extensions', + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="onnxruntime@microsoft.com", + url="https://github.com/microsoft/onnxruntime-extensions", ext_modules=ext_modules, cmdclass=_cmds.ortx_cmdclass, include_package_data=True, - install_requires=read_requirements(), + install_requires=[], classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX :: Linux', + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX :: Linux", "Programming Language :: C++", - 'Programming Language :: Python', + "Programming Language :: Python", "Programming Language :: Python :: Implementation :: CPython", - 'License :: OSI Approved :: MIT License' - ] + "License :: OSI Approved :: MIT License", + ], ) From f85c76bbf1ad9f0cc59f79b0f0adea89a84a4867 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Wed, 14 Feb 2024 23:28:10 +0000 Subject: [PATCH 2/3] update the ci.yml --- .pipelines/ci.yml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.pipelines/ci.yml b/.pipelines/ci.yml index 31dd0c0af..d84c9c5a8 100644 --- a/.pipelines/ci.yml +++ b/.pipelines/ci.yml @@ -89,8 +89,7 @@ stages: python -m pip install --upgrade pip python -m pip install --upgrade setuptools python -m pip install onnxruntime==$(ort.version) - python -m pip install -r requirements.txt - displayName: Install requirements.txt + displayName: Install requirements - script: | CPU_NUMBER=8 python -m pip install . @@ -283,8 +282,7 @@ stages: python -m pip install --upgrade setuptools python -m pip install --upgrade wheel python -m pip install onnxruntime==$(ort.version) - python -m pip install -r requirements.txt - displayName: Install requirements.txt + displayName: Install requirements - script: | python -c "import onnxruntime;print(onnxruntime.__version__)" @@ -419,7 +417,6 @@ stages: call activate pyenv python -m pip install --upgrade pip python -m pip install onnxruntime==$(ort.version) - python -m pip install -r requirements.txt python -m pip install -r requirements-dev.txt displayName: Install requirements{-dev}.txt and cmake python modules @@ -653,7 +650,6 @@ stages: python3 -m pip install --upgrade pip; \ python3 -m pip install --upgrade setuptools; \ python3 -m pip install onnxruntime-gpu==$(ORT_VERSION); \ - python3 -m pip install -r requirements.txt; \ python3 -m pip install -v --config-settings "ortx-user-option=use-cuda" . ; \ python3 -m pip install $(TORCH_VERSION) ; \ python3 -m pip install -r requirements-dev.txt; \ From 6a05ea5bcb98c008b35c179834aa3dd6a446fb0b Mon Sep 17 00:00:00 2001 From: Wenbing Li Date: Thu, 15 Feb 2024 12:00:01 -0800 Subject: [PATCH 3/3] add more message of missing ONNX package --- README.md | 9 ++++++--- onnxruntime_extensions/__init__.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d601ef7fd..40091f842 100644 --- a/README.md +++ b/README.md @@ -31,12 +31,15 @@ python -m pip install git+https://github.com/microsoft/onnxruntime-extensions.gi ## Usage -## 1. Generate the pre-/post- processing ONNX model -With onnxruntime-extensions Python package, you can easily get the ONNX processing graph by converting them from Huggingface transformer data processing classes, check the following API for details. +## 1. Generation of Pre-/Post-Processing ONNX Model +The `onnxruntime-extensions` Python package provides a convenient way to generate the ONNX processing graph. This can be achieved by converting the Huggingface transformer data processing classes into the desired format. For more detailed information, please refer to the API below: + ```python help(onnxruntime_extensions.gen_processing_models) ``` -### NOTE: These data processing model can be merged into other model [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed. +### NOTE: +The generation of model processing requires the **ONNX** package to be installed. The data processing models generated in this manner can be merged with other models using the [onnx.compose](https://onnx.ai/onnx/api/compose.html) if needed. + ## 2. Using Extensions for ONNX Runtime inference ### Python diff --git a/onnxruntime_extensions/__init__.py b/onnxruntime_extensions/__init__.py index 979d6f883..872c5a2b5 100644 --- a/onnxruntime_extensions/__init__.py +++ b/onnxruntime_extensions/__init__.py @@ -58,7 +58,16 @@ PyOp = PyCustomOpDef -if not _lib_only: +if _lib_only: + + def _unimplemented(*args, **kwargs): + raise NotImplementedError("ONNX or ONNX Runtime is not installed") + + gen_processing_models = _unimplemented + OrtPyFunction = _unimplemented + ort_inference = _unimplemented + +else: __all__ += _offline_api from ._cuops import * # noqa