Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX 2] Add ONNX tutorial using torch.onnx.dynamo_export API #2541

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

NOT_RUN = [
"beginner_source/basics/intro", # no code
"beginner_source/onnx/intro_onnx",
"beginner_source/translation_transformer",
"beginner_source/profiler",
"beginner_source/saving_loading_models",
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added _static/img/onnx/netron_web_ui.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion _templates/layout.html
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
</noscript>

<script type="text/javascript">
var collapsedSections = ['PyTorch Recipes', 'Learning PyTorch', 'Image and Video', 'Audio', 'Text', 'Reinforcement Learning', 'Deploying PyTorch Models in Production', 'Code Transforms with FX', 'Frontend APIs', 'Extending PyTorch', 'Model Optimization', 'Parallel and Distributed Training', 'Mobile'];
var collapsedSections = ['PyTorch Recipes', 'Learning PyTorch', 'Image and Video', 'Audio', 'Text', 'Backends', 'Reinforcement Learning', 'Deploying PyTorch Models in Production', 'Code Transforms with FX', 'Frontend APIs', 'Extending PyTorch', 'Model Optimization', 'Parallel and Distributed Training', 'Mobile'];
</script>

<img height="1" width="1" style="border-style:none;" alt="" src="https://www.googleadservices.com/pagead/conversion/795629140/?label=txkmCPmdtosBENSssfsC&amp;guid=ON&amp;script=0"/>
Expand Down
21 changes: 16 additions & 5 deletions advanced_source/super_resolution_with_onnxruntime.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
========================================================================
===================================================================================

.. Note::
thiagocrepaldi marked this conversation as resolved.
Show resolved Hide resolved
As of PyTorch 2.1, there are two versions of ONNX Exporter.

* ``torch.onnx.dynamo_export`is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0

In this tutorial, we describe how to convert a model defined
in PyTorch into the ONNX format and then run it with ONNX Runtime.
in PyTorch into the ONNX format using the TorchScript ``torch.onnx.export` ONNX exporter.

The exported model will be executed with ONNX Runtime.
ONNX Runtime is a performance-focused engine for ONNX models,
which inferences efficiently across multiple platforms and hardware
(Windows, Linux, and Mac and on both CPUs and GPUs).
Expand All @@ -15,13 +22,17 @@
For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__
and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.
You can get binary builds of ONNX and ONNX Runtime with
``pip install onnx onnxruntime``.

.. code-block:: bash

%%bash
pip install onnxruntime

ONNX Runtime recommends using the latest stable runtime for PyTorch.

"""

# Some standard imports
import io
import numpy as np

from torch import nn
Expand Down Expand Up @@ -185,7 +196,7 @@ def _initialize_weights(self):

import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
Expand Down
10 changes: 10 additions & 0 deletions beginner_source/onnx/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
ONNX
----

1. intro_onnx.py
Introduction to ONNX
https://pytorch.org/tutorials/onnx/intro_onnx.html

2. export_simple_model_to_onnx_tutorial.py
Export a PyTorch model to ONNX
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
211 changes: 211 additions & 0 deletions beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
**Export a PyTorch model to ONNX**

Export a PyTorch model to ONNX
==============================

**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_

.. note::
As of PyTorch 2.1, there are two versions of ONNX Exporter.

* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0

"""

###############################################################################
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
#
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
#
# ONNX is a flexible open standard format for representing machine learning models which standardized representations
# of machine learning allow them to be executed across a gamut of hardware platforms and runtime environments
# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone.
#
# In this tutorial, we’ll learn how to:
#
# 1. Install the required dependencies.
# 2. Author a simple image classifier model.
# 3. Export the model to ONNX format.
# 4. Save the ONNX model in a file.
# 5. Visualize the ONNX model graph using `Netron <https://github.com/lutzroeder/netron>`_.
# 6. Execute the ONNX model with `ONNX Runtime`
# 7. Compare the PyTorch results with the ones from the ONNX Runtime.
#
# 1. Install the required dependencies
# ------------------------------------
# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
# we will need to install them.
#
# .. code-block:: bash
#
# pip install onnx
# pip install onnxscript
#
# 2. Author a simple image classifier model
# -----------------------------------------
#
# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
# exactly like we did in the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_.
#

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):

def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

######################################################################
# 3. Export the model to ONNX format
# ----------------------------------
#
# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
# Next, we can export the model to ONNX format.

torch_model = MyModel()
torch_input = torch.randn(1, 1, 32, 32)
export_output = torch.onnx.dynamo_export(torch_model, torch_input)

######################################################################
# As we can see, we didn't need any code change to the model.
# The resulting ONNX model is stored within ``torch.onnx.ExportOutput`` as a binary protobuf file.
#
# 4. Save the ONNX model in a file
# --------------------------------
#
# Although having the exported model loaded in memory is useful in many applications,
# we can save it to disk with the following code:

export_output.save("my_image_classifier.onnx")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this work in the Google Colab?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a way to test with this pr preview? in theory, if the filesystem has read-write access, it should

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might need to mount a Google drive similar to this: https://pytorch.org/tutorials/beginner/colab.html#using-tutorial-data-from-google-drive-in-colab. Can maybe give a link to this section.


######################################################################
# The ONNX file can be loaded back into memory and checked if it is well formed with the following code:

import onnx
onnx_model = onnx.load("my_image_classifier.onnx")
onnx.checker.check_model(onnx_model)

######################################################################
# 5. Visualize the ONNX model graph using Netron
# ----------------------------------------------
#
# Now that we have our model saved in a file, we can visualize it with `Netron <https://github.com/lutzroeder/netron>`_.
# Netron can either be installed on macos, Linux or Windows computers, or run directly from the browser.
# Let's try the web version by opening the following link: https://netron.app/.
#
# .. image:: ../../_static/img/onnx/netron_web_ui.png
# :width: 70%
# :align: center
#
#
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
# clicking the **Open model** button.
#
# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
# :width: 50%
#
#
# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
#
# 6. Execute the ONNX model with ONNX Runtime
# -------------------------------------------
#
# The last step is executing the ONNX model with `ONNX Runtime`, but before we do that, let's install ONNX Runtime.
#
# .. code-block:: bash
#
# pip install onnxruntime
#
# The ONNX standard does not support all the data structure and types that PyTorch does,
# so we need to adapt PyTorch input's to ONNX format before feeding it to ONNX Runtime.
# In our example, the input happens to be the same, but it might have more inputs
# than the original PyTorch model in more complex models.
#
# ONNX Runtime requires an additional step that involves converting all PyTorch tensors to Numpy (in CPU)
# and wrap them on a dictionary with keys being a string with the input name as key and the numpy tensor as the value.
#
# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input
# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well.

import onnxruntime

onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
print(f"Input length: {len(onnx_input)}")
print(f"Sample input: {onnx_input}")

ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}

onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

######################################################################
# 7. Compare the PyTorch results with the ones from the ONNX Runtime
# -----------------------------------------------------------------
#
# The best way to determine whether the exported model is looking good is through numerical evaluation
# against PyTorch, which is our source of truth.
#
# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
thiagocrepaldi marked this conversation as resolved.
Show resolved Hide resolved
# Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.

torch_outputs = torch_model(torch_input)
torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

print("PyTorch and ONNX Runtime output matched!")
print(f"Output length: {len(onnxruntime_outputs)}")
print(f"Sample output: {onnxruntime_outputs}")

######################################################################
# Conclusion
# ----------
#
# That is about it! We have successfully exported our PyTorch model to ONNX format,
# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime
# and finally compared its numerical results with PyTorch's.
thiagocrepaldi marked this conversation as resolved.
Show resolved Hide resolved
#
# Further reading
# ---------------
#
# The list below refers to tutorials that ranges from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
# :hidden:
#
59 changes: 59 additions & 0 deletions beginner_source/onnx/intro_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
**Introduction to ONNX** ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consolidate all this in this PR: #2550?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That depends on when ONNX Runtime 1.16 will be released. This PR doesn't pass CI without it while #2550 does not need it.

I am monitoring the ort 1.16 release and will proceed as needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@svekars ORT 1.16 is out and the CI issue seems to be gone

There is another unrelated one, though:

wget -nv 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O _data/lenet_mnist_model.pth
https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl:
2023-09-20 19:46:41 ERROR 429: Too Many Requests.

`Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_

Introduction to ONNX
====================

Authors:
`Thiago Crepaldi <https://github.com/thiagocrepaldi>`_,

`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The ``torch.onnx`` module provides APIs to
capture the computation graph from a native PyTorch :class:`torch.nn.Module` model and convert
it into an `ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.

The exported model can be consumed by any of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_,
including Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.

.. note::
Currently, there are two flavors of ONNX exporter APIs,
but this tutorial will focus on the ``torch.onnx.dynamo_export``.

The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
The resulting FX Graph is polished before it is finally translated into an
`ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.

The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.

Dependencies
------------

The ONNX exporter depends on extra Python packages:

- `ONNX <https://onnx.ai>`_
- `ONNX Script <https://onnxscript.ai>`_

They can be installed through `pip <https://pypi.org/project/pip/>`_:

.. code-block:: bash

pip install --upgrade onnx onnxscript

Further reading
---------------

The list below refers to tutorials that ranges from basic examples to advanced scenarios,
not necessarily in the order they are listed.
Feel free to jump directly to specific topics of your interest or
sit tight and have fun going through all of them to learn all there is about the ONNX exporter.

.. include:: /beginner_source/onnx/onnx_toc.txt

.. toctree::
:hidden:

"""
1 change: 1 addition & 0 deletions beginner_source/onnx/onnx_toc.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
| 1. `Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_
5 changes: 5 additions & 0 deletions en-wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Lipschitz
logits
Lua
Luong
macos
MLP
MLPs
MNIST
Expand All @@ -147,11 +148,14 @@ NTK
NUMA
NaN
NanoGPT
Netron
NeurIPS
NumPy
Numericalization
Numpy's
ONNX
ONNX's
ONNX Runtime
OpenAI
OpenMP
Ornstein
Expand Down Expand Up @@ -386,6 +390,7 @@ prewritten
primals
profiler
profilers
protobuf
py
pytorch
quantized
Expand Down
Loading
Loading