Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Onnx node tests #7720

Merged
merged 17 commits into from
Mar 24, 2021
Merged

[ONNX] Onnx node tests #7720

merged 17 commits into from
Mar 24, 2021

Conversation

mbrookhart
Copy link
Contributor

@mbrookhart mbrookhart commented Mar 22, 2021

We've been hitting a lot of errors running models with ONNX, that has led to a lot of piecewise fixes. This is an attempt to fix the importer more broadly by running the tests onnx ships with pip https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node

These files contain an onnx graph, input arrays, and expected outputs, so we can test directly against the canonical onnx tests. This PR provides a method to import these tests as parameterized unit tests, execute them, and skip any we know currently fail. I also fixed a lot of low hanging fruit to reduce the number of skipped unit tests.

Future PRs will work to fix the currently skipped tests, and then extend this to GPU.

For reference, this is the pytest result on my system, testing against ONNX 1.6, which is what we have in CI:

434 passed, 123 skipped, 83 deselected, 1185 warnings in 32.40s

This adds a lot of tests, but they are all small, so the runtime is actually pretty minuscule, and it improves our ONNX import coverage dramatically.

cc @jwfromm @masahi @jroesch @electriclilies @adelbertc

Copy link
Contributor

@electriclilies electriclilies left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me, just a few nitpicks

channels = infer_channels(inputs[1], True)
out_type = infer_type(inputs[1])
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to work for layouts other than NCHW? It looks like the ONNX op doesn't specify layout in the ConvTranspose operator

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ONNX always assumes NCHW

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, just wanted to make sure we didn't have to worry about it!

tests/python/frontend/onnx/test_forward.py Outdated Show resolved Hide resolved
Copy link
Member

@jroesch jroesch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the next PR can you just paste you rational above the test section so that new comers understand what's going on?

@jroesch jroesch merged commit 8131364 into apache:main Mar 24, 2021
@mbrookhart mbrookhart deleted the onnx_node_tests branch March 24, 2021 17:34
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
* WIP

* some fixes

* more fixes

* fix some conv_transpose tests

* fix out of bounds slice

* fix flatten import

* fix logsoftmax and softmax tests

* fix Error in Upsample

* fix onehot

* normalize errors

* fix gather with negative indices

* parameterize test

* skip unsupported tests

* clean up

* fix rebase

* fix lint

* add an error message when we find an un-identified tensor
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
* WIP

* some fixes

* more fixes

* fix some conv_transpose tests

* fix out of bounds slice

* fix flatten import

* fix logsoftmax and softmax tests

* fix Error in Upsample

* fix onehot

* normalize errors

* fix gather with negative indices

* parameterize test

* skip unsupported tests

* clean up

* fix rebase

* fix lint

* add an error message when we find an un-identified tensor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants