Enable aimet-onnx acceptance tests #3305

Merged
9 changes: 9 additions & 0 deletions NightlyTests/CMakeLists.txt
@@ -48,10 +48,19 @@ if (ENABLE_TENSORFLOW)
AcceptanceTests.Tensorflow)
endif (ENABLE_TENSORFLOW)

if (ENABLE_ONNX)
add_dependencies(AcceptanceTests
AcceptanceTests.ONNX)
endif(ENABLE_ONNX)

if (ENABLE_TORCH AND NOT ENABLE_ONNX)
add_subdirectory(torch)
endif (ENABLE_TORCH AND NOT ENABLE_ONNX)

if (ENABLE_TENSORFLOW)
add_subdirectory(tensorflow)
endif (ENABLE_TENSORFLOW)

if (ENABLE_ONNX)
add_subdirectory(onnx)
endif(ENABLE_ONNX)
26 changes: 12 additions & 14 deletions NightlyTests/onnx/CMakeLists.txt
@@ -2,7 +2,7 @@
#
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2018, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2018-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -36,17 +36,15 @@
#
#=============================================================================

if (ENABLE_CUDA)
set(CUDA_FLAG "not blah")
set(USE_CUDA True)
else (ENABLE_CUDA)
set(CUDA_FLAG "not cuda")
set(USE_CUDA False)
endif (ENABLE_CUDA)

add_custom_target( AcceptanceTests.ONNX )

file(GLOB files "${CMAKE_CURRENT_SOURCE_DIR}/test_*.py")
message("Start to find tests in ONNX")
foreach(filename ${files})
message("Find: " ${filename})
get_filename_component( testname "${filename}" NAME_WE )
add_custom_target(AcceptanceTests.ONNX.${testname}
VERBATIM COMMAND ${CMAKE_COMMAND} -E env
"${AIMET_PYTHONPATH}"
pytest -s ${filename} --junitxml=${CMAKE_CURRENT_BINARY_DIR}/py_test_output_${testname}.xml)

endforeach( filename )
add_custom_target(AcceptanceTests.ONNX
VERBATIM COMMAND ${CMAKE_COMMAND} -E env
"${AIMET_PYTHONPATH}"
${Python3_EXECUTABLE} -m pytest -s ${CMAKE_CURRENT_SOURCE_DIR} -m ${CUDA_FLAG} --junitxml=${CMAKE_CURRENT_BINARY_DIR}/py_test_output.xml)
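Note on the `-m ${CUDA_FLAG}` switch: CUDA_FLAG is a pytest marker expression, set to "not blah" when CUDA is enabled (which selects everything, presumably because no test carries a "blah" marker) and to "not cuda" otherwise, so tests decorated with @pytest.mark.cuda are deselected on CPU-only builds. A minimal illustrative sketch of how such a marker gates a test (the test body below is hypothetical, not part of this change):

import pytest

@pytest.mark.cuda
def test_gpu_only_path():
    # Deselected by `pytest -m "not cuda"`, collected by `pytest -m "not blah"`.
    # A real acceptance test would build a model and run GPU quantization here.
    assert True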
45 changes: 21 additions & 24 deletions NightlyTests/onnx/test_adaround.py
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -34,8 +34,8 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import os
from packaging import version
import json
import numpy as np
import pytest
@@ -45,12 +45,9 @@
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from torchvision import models

from aimet_onnx.utils import make_dummy_input
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel
from torch_utils import get_cifar10_data_loaders, train_cifar10
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession

from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters

image_size = 32
@@ -81,33 +78,32 @@ class TestAdaroundAcceptance:
""" Acceptance test for AIMET ONNX """
@pytest.mark.cuda
def test_adaround(self):
if version.parse(torch.__version__) >= version.parse("1.13"):
np.random.seed(0)
torch.manual_seed(0)

model = get_model()

data_loader = dataloader()
dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
sess = build_session(model)
out_before_ada = sess.run(None, dummy_input)
def callback(session, args):
in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
session.run(None, in_tensor)

params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback,
forward_pass_callback_args=None)
ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy')
np.random.seed(0)
torch.manual_seed(0)
model = get_model()
data_loader = dataloader()
dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
sess = build_session(model)
out_before_ada = sess.run(None, dummy_input)
def callback(session, args):
in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
session.run(None, in_tensor)

params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback,
forward_pass_callback_args=None)

with tempfile.TemporaryDirectory() as tmpdir:
ada_rounded_model = Adaround.apply_adaround(model, params, tmpdir, 'dummy')
sess = build_session(ada_rounded_model)
out_after_ada = sess.run(None, dummy_input)
assert not np.array_equal(out_before_ada[0], out_after_ada[0])

with open('./dummy.encodings') as json_file:
with open(os.path.join(tmpdir, 'dummy.encodings')) as json_file:
encoding_data = json.load(json_file)

sim = QuantizationSimModel(ada_rounded_model, dummy_input, quant_scheme=QuantScheme.post_training_tf, default_param_bw=8,
default_activation_bw=8, use_cuda=True)
sim.set_and_freeze_param_encodings('./dummy.encodings')
sim.set_and_freeze_param_encodings(os.path.join(tmpdir, 'dummy.encodings'))
sim.compute_encodings(callback, None)
assert sim.qc_quantize_op_dict['fc.weight'].encodings[0].delta == encoding_data['fc.weight'][0]['scale']

@@ -151,6 +147,7 @@ def __len__(self):
dummy_dataloader = DataLoader(batch_size=2)
return dummy_dataloader


def build_session(model):
"""
Build and return onnxruntime inference session
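For reference, the flow this test exercises — run AdaRound, then feed the produced encodings into a QuantizationSimModel — condenses to roughly the sketch below. It reuses the names defined in the test above (model, data_loader, dummy_input, callback) and is a sketch of the test's flow, not a canonical recipe:

import os
import tempfile

from aimet_common.defs import QuantScheme
from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_onnx.quantsim import QuantizationSimModel

params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5,
                            forward_fn=callback, forward_pass_callback_args=None)

with tempfile.TemporaryDirectory() as tmpdir:
    # apply_adaround returns the weight-rounded model and writes 'dummy.encodings' into tmpdir
    ada_model = Adaround.apply_adaround(model, params, tmpdir, 'dummy')
    sim = QuantizationSimModel(ada_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                               default_param_bw=8, default_activation_bw=8, use_cuda=True)
    # Freeze the AdaRound parameter encodings before calibrating activation encodings
    sim.set_and_freeze_param_encodings(os.path.join(tmpdir, 'dummy.encodings'))
    sim.compute_encodings(callback, None)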
4 changes: 3 additions & 1 deletion NightlyTests/onnx/test_cross_layer_equalization.py
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -34,6 +34,7 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import numpy as np
import pytest

@@ -43,6 +44,7 @@

class TestCLEAcceptance:
""" Acceptance test for AIMET ONNX """
@pytest.mark.skip(reason="Find better test criteria.")
@pytest.mark.parametrize('model', [test_models.mobilenetv2(), test_models.mobilenetv3_large_model()])
def test_cle_mv2(self, model):
""" Test for E2E quantization """
7 changes: 0 additions & 7 deletions NightlyTests/onnx/test_mixed_precision.py
@@ -37,14 +37,10 @@

""" Test top level mixed precision api """

import os
import tempfile

import pytest
import unittest
import unittest.mock

import json
import numpy as np
from test_models import resnet18
from aimet_onnx.mixed_precision import choose_mixed_precision
@@ -59,8 +55,6 @@

class TestMixedPrecision:
""" Test case for mixed precision api """


@pytest.mark.cuda
def test_quantize_with_mixed_precision(self):
""" Test top level quantize_with_mixed_precision api """
@@ -147,7 +141,6 @@ def __iter__(self):
pareto_eval_scores = [eval_score for _, eval_score, _, _ in pareto_front_list]
assert eval_score in pareto_eval_scores


def forward_pass_callback(session, inp_shape):
""" Call mnist_evaluate setting use_cuda to True, iterations=5 """
inputs = np.random.rand(*inp_shape).astype(np.float32)
21 changes: 12 additions & 9 deletions NightlyTests/onnx/test_quantsim.py
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2022-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -34,8 +34,8 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
import os

import os
import numpy as np
import pytest
import tempfile
@@ -44,10 +44,14 @@
from torchvision import models

from aimet_onnx.utils import make_dummy_input
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from torch_utils import get_cifar10_data_loaders, train_cifar10
try:
from torch_utils import get_cifar10_data_loaders, train_cifar10
except (ImportError, OSError):
pass
# TODO (hitameht): For onnx-cpu variant, fix OSError: libtorch_hip.so: cannot open shared object file: No such file or directory

image_size = 32
batch_size = 64
@@ -95,13 +99,14 @@ def test_quantized_accuracy(self, config_file):
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'},
}
},
opset_version = 12,
)

onnx_model = load_model(os.path.join(tmp_dir, 'resnet18.onnx'))
dummy_input = make_dummy_input(onnx_model)
sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf, default_param_bw=8,
default_activation_bw=8, use_cuda=True, config_file=config_file)
sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
default_param_bw=8, default_activation_bw=8, use_cuda=True, config_file=config_file)

def onnx_callback(session, iters):
for i, batch in enumerate(train_loader):
@@ -112,7 +117,5 @@ def onnx_callback(session, iters):
break

sim.compute_encodings(onnx_callback, 10)

onnx_qs_acc = model_eval_onnx(sim.session, val_loader)

assert onnx_qs_acc > 0.5
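The pattern used here — export the trained torch model to ONNX, wrap it in a QuantizationSimModel, then calibrate with compute_encodings — can be summarized as the sketch below; pt_model, dummy_torch_input, and calibration_callback are placeholders for the trained ResNet-18, its example input, and the forward-pass callback from the test:

import os
import tempfile

import torch
from onnx import load_model

from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.utils import make_dummy_input

with tempfile.TemporaryDirectory() as tmp_dir:
    onnx_path = os.path.join(tmp_dir, 'model.onnx')
    # Export with a dynamic batch dimension so calibration/eval batch sizes can differ
    torch.onnx.export(pt_model, dummy_torch_input, onnx_path,
                      input_names=['input'], output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
                      opset_version=12)

    onnx_model = load_model(onnx_path)
    dummy_input = make_dummy_input(onnx_model)
    sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                               default_param_bw=8, default_activation_bw=8, use_cuda=True,
                               config_file=None)
    sim.compute_encodings(calibration_callback, 10)
    # sim.session now simulates quantized inference via onnxruntime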
13 changes: 9 additions & 4 deletions NightlyTests/onnx/test_rnn_quantsim.py
@@ -34,20 +34,24 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
import os

import os
import numpy as np
import pytest
import tempfile
import torch
from onnx import load_model
from torchaudio import models

from aimet_onnx.utils import make_dummy_input
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from torch_utils import get_librispeech_data_loaders, train_librispeech
try:
from torch_utils import get_librispeech_data_loaders, train_librispeech
from torchaudio import models
except (ImportError, OSError):
pass
# TODO (hitameht): For onnx-cpu variant, fix OSError: libtorch_hip.so: cannot open shared object file: No such file or directory

batch_size = 64
n_feature = 128
@@ -85,6 +89,7 @@ class TestQuantizeAcceptance:
""" Acceptance test for AIMET ONNX """
@pytest.mark.parametrize("config_file", [None, get_path_for_per_channel_config()])
@pytest.mark.cuda
@pytest.mark.skip(reason="Figure out a way to download datasets.")
def test_quantized_accuracy(self, config_file):
with tempfile.TemporaryDirectory() as tmp_dir:
np.random.seed(0)
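The guarded import added above keeps test collection working on the onnx-cpu variant, where importing the torch-based helpers currently fails with an OSError (missing libtorch_hip.so) rather than an ImportError. A sketch of the pattern, with a note on the more common pytest.importorskip alternative (not used here because it only catches ImportError):

try:
    from torch_utils import get_librispeech_data_loaders, train_librispeech
    from torchaudio import models
except (ImportError, OSError):
    # e.g. "OSError: libtorch_hip.so: cannot open shared object file"; the tests that
    # need these symbols are CUDA-marked and currently skipped, so collection proceeds.
    pass

# Alternative sketch (not what this change does): pytest.importorskip("torchaudio") would
# skip the module's tests on ImportError, but it would not swallow the OSError above,
# hence the explicit try/except.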
1 change: 0 additions & 1 deletion NightlyTests/onnx/torch_utils.py
@@ -147,7 +147,6 @@ def get_librispeech_data_loaders(batch_size=64, num_workers=4, drop_last=True):
return train_loader, val_loader



def model_train(model: torch.nn.Module, train_loader: DataLoader, epochs: int, optimizer: optim.Optimizer, scheduler):
"""
Trains the given torch model for the specified number of epochs