fix quantization for onnxruntime v1.16.0 (#1405)
* fix quantization for onnxruntime v1.16.0

* Update optimum/onnxruntime/quantization.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/onnxruntime/quantization.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/onnxruntime/quantization.py

Co-authored-by: fxmarty <[email protected]>

* skip test for ort v1.16.0

---------

Co-authored-by: fxmarty <[email protected]>
echarlaix and fxmarty authored Sep 21, 2023
1 parent 8383fb3 commit 3c4ad78
Showing 3 changed files with 27 additions and 5 deletions.
13 changes: 11 additions & 2 deletions optimum/onnxruntime/quantization.py
@@ -279,6 +279,10 @@ def compute_ranges(self) -> Dict[str, Tuple[float, float]]:
)

LOGGER.info("Computing calibration ranges")

if parse(ort_version) >= Version("1.16.0"):
return self._calibrator.compute_data()

return self._calibrator.compute_range()

def quantize(
@@ -351,8 +355,13 @@ def quantize(
has_subgraphs = True
break

if quantization_config.is_static and has_subgraphs:
raise NotImplementedError("Static quantization is currently not supported for models with" " subgraphs.")
if has_subgraphs:
if quantization_config.is_static:
raise NotImplementedError("Static quantization is currently not supported for models with subgraphs.")
if parse(ort_version) == Version("1.16.0"):
raise ValueError(
"ONNX Runtime version v1.16.0 is not compatible with quantization for models with subgraphs, please downgrade to 1.15.1 or upgrade to a higher version. Reference: https://github.com/microsoft/onnxruntime/pull/17651"
)

quantizer_factory = QDQQuantizer if use_qdq else ONNXQuantizer

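For context, the functional change above tracks an API rename in ONNX Runtime 1.16.0, where the calibrator's compute_range() became compute_data(). A minimal standalone sketch of the same version dispatch, taking any calibrator object (the helper name compute_calibration_ranges is hypothetical):

from onnxruntime import __version__ as ort_version
from packaging.version import Version, parse


def compute_calibration_ranges(calibrator):
    # ONNX Runtime 1.16.0 renamed the calibrator's compute_range() to
    # compute_data(); dispatch on the installed version so both APIs work.
    if parse(ort_version) >= Version("1.16.0"):
        return calibrator.compute_data()
    return calibrator.compute_range()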
15 changes: 13 additions & 2 deletions tests/cli/test_cli.py
@@ -21,6 +21,9 @@
import unittest
from pathlib import Path

from onnxruntime import __version__ as ort_version
from packaging.version import Version, parse

import optimum.commands


@@ -84,14 +87,22 @@ def test_quantize_commands(self):
export_commands = [
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
# f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
]
quantize_commands = [
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder --avx2 -o {tempdir}/quantized_encoder",
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/decoder --avx2 -o {tempdir}/quantized_decoder",
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder",
# f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder",
]

if parse(ort_version) != Version("1.16.0"):
export_commands.append(
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder"
)
quantize_commands.append(
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder"
)

for export, quantize in zip(export_commands, quantize_commands):
subprocess.run(export, shell=True, check=True)
subprocess.run(quantize, shell=True, check=True)
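Note that the T5 round trip is gated with an exact equality check rather than a lower bound, since the error message in quantization.py points to a fix landing after 1.16.0 (downgrade to 1.15.1 or upgrade). A quick sketch of how these packaging.version comparisons behave:

from packaging.version import Version, parse

# parse() returns a Version, so comparisons are numeric, not lexicographic:
assert parse("1.16.0") == Version("1.16.0")  # the broken release is pinned exactly
assert parse("1.15.1") < Version("1.16.0")   # older releases still run the T5 commands
assert parse("1.16.1") > Version("1.16.0")   # so would a patched later release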
4 changes: 3 additions & 1 deletion tests/onnxruntime/test_quantization.py
@@ -19,7 +19,9 @@
from pathlib import Path

from onnx import load as onnx_load
from onnxruntime import __version__ as ort_version
from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
from packaging.version import Version, parse
from parameterized import parameterized
from transformers import AutoTokenizer

@@ -112,9 +114,9 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_matmuls):
self.assertEqual(expected_quantized_matmuls, num_quantized_matmul)
gc.collect()

@unittest.skipIf(parse(ort_version) == Version("1.16.0"), "not supported with this onnxruntime version")
def test_dynamic_quantization_subgraphs(self):
qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True)
# with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = tempfile.mkdtemp()
output_dir = Path(tmp_dir)
model = ORTModelForCausalLM.from_pretrained(
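The same pin can guard any other test that exercises subgraph quantization; a minimal sketch, assuming a hypothetical test class and the usual unittest layout:

import unittest

from onnxruntime import __version__ as ort_version
from packaging.version import Version, parse


class QuantizationGuardTest(unittest.TestCase):  # hypothetical class for illustration
    # ORT 1.16.0 cannot quantize models with subgraphs
    # (see https://github.com/microsoft/onnxruntime/pull/17651),
    # so skip on that exact release instead of failing.
    @unittest.skipIf(parse(ort_version) == Version("1.16.0"), "not supported with this onnxruntime version")
    def test_subgraph_quantization(self):
        self.assertTrue(True)  # placeholder body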
