RunONNXModel.py: Add a --cache-model=path option (#2984)
* added cache

Signed-off-by: Alexandre Eichenberger <[email protected]>

---------

Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger authored Oct 25, 2024
1 parent 91a72e6 commit 0aa652f
Showing 1 changed file with 51 additions and 24 deletions.
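The new --cache-model flag folds --save-model and --load-model into a single caching workflow: when model.so already exists in the given folder the run loads it, otherwise the run compiles the model and saves it there. A minimal usage sketch in Python, assuming a hypothetical model.onnx in the repository root and invoking the script twice with the same cache folder (the -O3 compile flag is likewise illustrative):

    # Sketch: the first run compiles and populates the cache; the second reuses it.
    import subprocess

    cmd = [
        "python", "utils/RunONNXModel.py",
        "--model", "model.onnx",          # hypothetical input model
        "--compile-args=-O3",             # illustrative onnx-mlir flags
        "--cache-model", "./model_cache",
    ]
    subprocess.run(cmd, check=True)  # compiles, saves model.so into ./model_cache
    subprocess.run(cmd, check=True)  # finds ./model_cache/model.so, skips compilation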
75 changes: 51 additions & 24 deletions utils/RunONNXModel.py
@@ -63,59 +63,59 @@ def check_non_negative(argname, value):
nargs="?",
const="compilation.log",
default=None,
help="Output compilation messages to file, default compilation.log",
help="Output compilation messages to file, default compilation.log.",
)
parser.add_argument(
"-m",
"--model",
type=lambda s: valid_onnx_input(s),
help="Path to an ONNX model (.onnx or .mlir)",
help="Path to an ONNX model (.onnx or .mlir).",
)
parser.add_argument(
"-c",
"--compile-args",
type=str,
default="",
help="Arguments passed directly to onnx-mlir command." " See bin/onnx-mlir --help",
help="Arguments passed directly to onnx-mlir command." " See bin/onnx-mlir --help.",
)
parser.add_argument(
"-C", "--compile-only", action="store_true", help="Only compile the input model"
"-C", "--compile-only", action="store_true", help="Only compile the input model."
)
parser.add_argument(
"--compile-using-input-shape",
action="store_true",
help="Compile the model by using the shape info getting from"
" the inputs in the reference folder set by --load-ref",
)
parser.add_argument("--print-input", action="store_true", help="Print out inputs")
parser.add_argument("--print-input", action="store_true", help="Print out inputs.")
parser.add_argument(
"--print-output",
action="store_true",
help="Print out inference outputs produced by onnx-mlir",
help="Print out inference outputs produced by onnx-mlir.",
)
parser.add_argument(
"--print-signatures",
action="store_true",
help="Print out the input and output signatures of the model",
help="Print out the input and output signatures of the model.",
)
parser.add_argument(
"--save-onnx",
metavar="PATH",
type=str,
help="File path to save the onnx model. Only effective if " "--verify=onnxruntime",
help="File path to save the onnx model. Only effective if --verify=onnxruntime.",
)
parser.add_argument(
"--verify",
choices=["onnxruntime", "ref"],
help="Verify the output by using onnxruntime or reference"
" inputs/outputs. By default, no verification. When being"
" enabled, --verify-with-softmax or --verify-every-value"
" must be used to specify verification mode",
" must be used to specify verification mode.",
)
parser.add_argument(
"--verify-all-ops",
action="store_true",
help="Verify all operation outputs when using onnxruntime",
help="Verify all operation outputs when using onnxruntime.",
)
parser.add_argument(
"--verify-with-softmax",
@@ -129,46 +129,53 @@ def check_non_negative(argname, value):
parser.add_argument(
"--verify-every-value",
action="store_true",
help="Verify every value of the output using atol and rtol",
help="Verify every value of the output using atol and rtol.",
)
parser.add_argument(
"--rtol", type=str, default="0.05", help="Relative tolerance for verification"
"--rtol", type=str, default="0.05", help="Relative tolerance for verification."
)
parser.add_argument(
"--atol", type=str, default="0.01", help="Absolute tolerance for verification"
"--atol", type=str, default="0.01", help="Absolute tolerance for verification."
)

lib_group = parser.add_mutually_exclusive_group()
lib_group.add_argument(
"--save-model",
metavar="PATH",
type=str,
help="Path to a folder to save the compiled model",
help="Path to a folder to save the compiled model.",
)
lib_group.add_argument(
"--load-model",
metavar="PATH",
type=str,
help="Path to a folder to load a compiled model for "
"inference, and the ONNX model will not be re-compiled",
"inference, and the ONNX model will not be re-compiled.",
)
lib_group.add_argument(
"--cache-model",
metavar="PATH",
type=str,
help="When finding a compiled model in given path, reuse it. "
"Otherwise, compile model and save it into the given path.",
)

parser.add_argument(
"--save-ref",
metavar="PATH",
type=str,
help="Path to a folder to save the inputs and outputs in protobuf",
help="Path to a folder to save the inputs and outputs in protobuf.",
)
data_group = parser.add_mutually_exclusive_group()
data_group.add_argument(
"--load-ref",
metavar="PATH",
type=str,
help="Path to a folder containing reference inputs and outputs stored in protobuf."
" If --verify=ref, inputs and outputs are reference data for verification",
" If --verify=ref, inputs and outputs are reference data for verification.",
)
data_group.add_argument(
"--inputs-from-arrays", help="List of numpy arrays used as inputs for inference"
"--inputs-from-arrays", help="List of numpy arrays used as inputs for inference."
)
data_group.add_argument(
"--load-ref-from-numpy",
@@ -182,7 +189,7 @@ def check_non_negative(argname, value):
"--shape-info",
type=str,
help="Shape for each dynamic input of the model, e.g. 0:1x10x20,1:7x5x3. "
"Used to generate random inputs for the model if --load-ref is not set",
"Used to generate random inputs for the model if --load-ref is not set.",
)

parser.add_argument(
@@ -191,38 +198,39 @@ def check_non_negative(argname, value):
help="Lower bound values for each data type. Used inputs."
" E.g. --lower-bound=int64:-10,float32:-0.2,uint8:1."
" Supported types are bool, uint8, int8, uint16, int16, uint32, int32,"
" uint64, int64,float16, float32, float64",
" uint64, int64,float16, float32, float64.",
)
parser.add_argument(
"--upper-bound",
type=str,
help="Upper bound values for each data type. Used to generate random inputs."
" E.g. --upper-bound=int64:10,float32:0.2,uint8:9."
" Supported types are bool, uint8, int8, uint16, int16, uint32, int32,"
" uint64, int64, float16, float32, float64",
" uint64, int64, float16, float32, float64.",
)
parser.add_argument(
"-w",
"--warmup",
type=lambda s: check_non_negative("--warmup", s),
default=0,
help="The number of warmup inference runs",
help="The number of warmup inference runs.",
)
parser.add_argument(
"-n",
"--n-iteration",
type=lambda s: check_positive("--n-iteration", s),
default=1,
help="The number of inference runs excluding warmup",
help="The number of inference runs excluding warmup.",
)
parser.add_argument(
"--seed",
type=str,
default="42",
help="seed to initialize the random num generator for inputs",
help="seed to initialize the random num generator for inputs.",
)

args = parser.parse_args()

if args.verify and (args.verify_with_softmax is None) and (not args.verify_every_value):
raise RuntimeError(
"Choose verification mode: --verify-with-softmax or "
@@ -639,6 +647,25 @@ def __init__(self, model_file, **kwargs):
# e.g. model.so, model.constants.bin, ...
self.default_model_name = "model"

# Handle cache_model.
if args.cache_model:
shared_lib_path = args.cache_model + f"/{self.default_model_name}.so"
if not os.path.exists(shared_lib_path):
print(
'Cached compiled model not found in "'
+ args.cache_model
+ '": save model this run.'
)
args.save_model = args.cache_model
else:
print(
'Cached compiled model found in "'
+ args.cache_model
+ '": load model this run.'
)
args.load_model = args.cache_model
args.cache_model = None

# Get shape information if given.
# args.shape_info in the form of 'input_index:d1xd2, input_index:d1xd2'
input_shapes = {}
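Because --cache-model is registered in the same mutually exclusive group as --save-model and --load-model, argparse rejects any combination of the three up front; the handler above then rewrites args.cache_model into whichever of the other two applies and clears it. A minimal sketch of that argparse behavior, using placeholder option names that mirror the ones above:

    # Sketch of the mutually exclusive group pattern used by lib_group.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--save-model", metavar="PATH")
    group.add_argument("--load-model", metavar="PATH")
    group.add_argument("--cache-model", metavar="PATH")

    args = parser.parse_args(["--cache-model", "./cache"])  # accepted
    # parser.parse_args(["--cache-model", "./cache", "--load-model", "./x"])
    # -> error: argument --load-model: not allowed with argument --cache-model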
