diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index cd23c0e05d..596ab1700c 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -10,6 +10,7 @@
NOT_RUN = [
"beginner_source/basics/intro", # no code
+ "beginner_source/onnx/intro_onnx",
"beginner_source/translation_transformer",
"beginner_source/profiler",
"beginner_source/saving_loading_models",
diff --git a/_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png b/_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
new file mode 100755
index 0000000000..0c29c16879
Binary files /dev/null and b/_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png differ
diff --git a/_static/img/onnx/netron_web_ui.png b/_static/img/onnx/netron_web_ui.png
new file mode 100755
index 0000000000..f88936eb82
Binary files /dev/null and b/_static/img/onnx/netron_web_ui.png differ
diff --git a/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
new file mode 100755
index 0000000000..00156df042
Binary files /dev/null and b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png differ
diff --git a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
index 426a14d98f..00156df042 100644
Binary files a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png and b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png differ
diff --git a/_templates/layout.html b/_templates/layout.html
index 660c687021..242e347d09 100644
--- a/_templates/layout.html
+++ b/_templates/layout.html
@@ -107,7 +107,7 @@
diff --git a/advanced_source/super_resolution_with_onnxruntime.py b/advanced_source/super_resolution_with_onnxruntime.py
index 835a79bd3a..f0a1894896 100644
--- a/advanced_source/super_resolution_with_onnxruntime.py
+++ b/advanced_source/super_resolution_with_onnxruntime.py
@@ -1,10 +1,17 @@
"""
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
-========================================================================
+===================================================================================
+
+.. note::
+    As of PyTorch 2.1, there are two versions of the ONNX exporter.
+
+    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
+    * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.
+
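+As a minimal sketch of the two entry points (assuming a ``torch.nn.Module``
+instance ``model`` and an example input tensor ``x``; both names are
+illustrative, not part of this tutorial's code):
+
+.. code-block:: python
+
+    import torch
+
+    # TorchDynamo-based exporter (beta): returns an in-memory export object
+    # that can be saved to disk in a separate step.
+    export_output = torch.onnx.dynamo_export(model, x)
+    export_output.save("model.onnx")
+
+    # TorchScript-based exporter: traces the model and writes the ONNX file directly.
+    torch.onnx.export(model, x, "model.onnx")
+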
In this tutorial, we describe how to convert a model defined
-in PyTorch into the ONNX format and then run it with ONNX Runtime.
+in PyTorch into the ONNX format using the TorchScript-based ``torch.onnx.export`` ONNX exporter.
+The exported model will be executed with ONNX Runtime.
ONNX Runtime is a performance-focused engine for ONNX models,
which inferences efficiently across multiple platforms and hardware
(Windows, Linux, and Mac and on both CPUs and GPUs).
@@ -15,13 +22,17 @@
For this tutorial, you will need to install `ONNX `__
and `ONNX Runtime `__.
You can get binary builds of ONNX and ONNX Runtime with
-``pip install onnx onnxruntime``.
+
+.. code-block:: bash
+
+   pip install onnx onnxruntime
+
ONNX Runtime recommends using the latest stable runtime for PyTorch.
"""
# Some standard imports
-import io
import numpy as np
from torch import nn
@@ -185,7 +196,7 @@ def _initialize_weights(self):
import onnxruntime
-ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
+ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
diff --git a/beginner_source/onnx/README.txt b/beginner_source/onnx/README.txt
new file mode 100644
index 0000000000..f73ed11bc8
--- /dev/null
+++ b/beginner_source/onnx/README.txt
@@ -0,0 +1,10 @@
+ONNX
+----
+
+1. intro_onnx.py
+ Introduction to ONNX
+ https://pytorch.org/tutorials/onnx/intro_onnx.html
+
+2. export_simple_model_to_onnx_tutorial.py
+ Export a PyTorch model to ONNX
+ https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
diff --git a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
new file mode 100644
index 0000000000..fa09dc86ab
--- /dev/null
+++ b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -0,0 +1,211 @@
+# -*- coding: utf-8 -*-
+"""
+`Introduction to ONNX `_ ||
+**Export a PyTorch model to ONNX**
+
+Export a PyTorch model to ONNX
+==============================
+
+**Author**: `Thiago Crepaldi `_
+
+.. note::
+    As of PyTorch 2.1, there are two versions of the ONNX exporter.
+
+    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
+    * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.
+
+"""
+
+###############################################################################
+# In the `60 Minute Blitz `_,
+# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
+# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
+# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+#
+# While PyTorch is great for iterating on the development of models, the model can be deployed to production
+# using different formats, including `ONNX `_ (Open Neural Network Exchange)!
+#
+# ONNX is a flexible open standard format for representing machine learning models. This standardized
+# representation allows models to be executed across a gamut of hardware platforms and runtime environments,
+# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone.
+#
+# In this tutorial, we’ll learn how to:
+#
+# 1. Install the required dependencies.
+# 2. Author a simple image classifier model.
+# 3. Export the model to ONNX format.
+# 4. Save the ONNX model in a file.
+# 5. Visualize the ONNX model graph using `Netron `_.
+# 6. Execute the ONNX model with ONNX Runtime.
+# 7. Compare the PyTorch results with the ones from the ONNX Runtime.
+#
+# 1. Install the required dependencies
+# ------------------------------------
+# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
+# we will need to install them.
+#
+# .. code-block:: bash
+#
+# pip install onnx
+# pip install onnxscript
+#
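+# To sanity-check the installation, you can query the installed package versions
+# (an optional snippet; it assumes only that the two packages above installed successfully):
+#
+# .. code-block:: python
+#
+#    import importlib.metadata
+#
+#    # Each call raises ``PackageNotFoundError`` if the package is missing.
+#    print(importlib.metadata.version("onnx"))
+#    print(importlib.metadata.version("onnxscript"))
+#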
+# 2. Author a simple image classifier model
+# -----------------------------------------
+#
+# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
+# exactly like we did in the `60 Minute Blitz `_.
+#
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MyModel(nn.Module):
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.conv1 = nn.Conv2d(1, 6, 5)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
+ x = F.max_pool2d(F.relu(self.conv2(x)), 2)
+ x = torch.flatten(x, 1)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+######################################################################
+# 3. Export the model to ONNX format
+# ----------------------------------
+#
+# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
+# Next, we can export the model to ONNX format.
+
+torch_model = MyModel()
+torch_input = torch.randn(1, 1, 32, 32)
+export_output = torch.onnx.dynamo_export(torch_model, torch_input)
+
+######################################################################
+# As we can see, the export did not require any changes to the model code.
+# The resulting ONNX model is held in memory by the returned ``torch.onnx.ExportOutput`` object as a protobuf representation.
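+#
+# For example, the in-memory model can be inspected before saving it
+# (a small optional aside; ``model_proto`` is the ``onnx.ModelProto``
+# representation held by the ``ExportOutput`` object):
+#
+# .. code-block:: python
+#
+#    # Show which ONNX opset the exporter targeted.
+#    print(export_output.model_proto.opset_import)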
+#
+# 4. Save the ONNX model in a file
+# --------------------------------
+#
+# Although having the exported model loaded in memory is useful in many applications,
+# we can save it to disk with the following code:
+
+export_output.save("my_image_classifier.onnx")
+
+######################################################################
+# The ONNX file can be loaded back into memory and checked for well-formedness with the following code:
+
+import onnx
+onnx_model = onnx.load("my_image_classifier.onnx")
+onnx.checker.check_model(onnx_model)
+
+######################################################################
+# 5. Visualize the ONNX model graph using Netron
+# ----------------------------------------------
+#
+# Now that we have our model saved in a file, we can visualize it with `Netron `_.
+# Netron can either be installed on macOS, Linux, or Windows, or run directly in the browser.
+# Let's try the web version by opening the following link: https://netron.app/.
+#
+# .. image:: ../../_static/img/onnx/netron_web_ui.png
+# :width: 70%
+# :align: center
+#
+#
+# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
+# clicking the **Open model** button.
+#
+# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
+# :width: 50%
+#
+#
+# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
+#
+# 6. Execute the ONNX model with ONNX Runtime
+# -------------------------------------------
+#
+# The last step is executing the ONNX model with ONNX Runtime, but before we do that, let's install ONNX Runtime.
+#
+# .. code-block:: bash
+#
+# pip install onnxruntime
+#
+# The ONNX standard does not support all the data structures and types that PyTorch does,
+# so we need to adapt the PyTorch inputs to the ONNX format before feeding them to ONNX Runtime.
+# In our example the input happens to be unchanged, but more complex models
+# can end up with more inputs than the original PyTorch model.
+#
+# ONNX Runtime requires an additional step: converting all PyTorch tensors to NumPy arrays (on CPU)
+# and wrapping them in a dictionary that maps each input name to its NumPy tensor.
+#
+# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input
+# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well.
+
+import onnxruntime
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
+print(f"Input length: {len(onnx_input)}")
+print(f"Sample input: {onnx_input}")
+
+ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+######################################################################
+# 7. Compare the PyTorch results with the ones from the ONNX Runtime
+# --------------------------------------------------------------------
+#
+# The best way to determine whether the exported model is working as expected is through numerical
+# evaluation against PyTorch, which is our source of truth.
+#
+# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
+# Before comparing the results, we need to convert PyTorch's output to match ONNX's format.
+
+torch_outputs = torch_model(torch_input)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+ torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+print("PyTorch and ONNX Runtime output matched!")
+print(f"Output length: {len(onnxruntime_outputs)}")
+print(f"Sample output: {onnxruntime_outputs}")
+
+######################################################################
+# Conclusion
+# ----------
+#
+# That is about it! We have successfully exported our PyTorch model to ONNX format,
+# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime
+# and finally compared its numerical results with PyTorch's.
+#
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of interest, or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+#    :hidden:
+#
\ No newline at end of file
diff --git a/beginner_source/onnx/intro_onnx.py b/beginner_source/onnx/intro_onnx.py
new file mode 100644
index 0000000000..05ad3090cc
--- /dev/null
+++ b/beginner_source/onnx/intro_onnx.py
@@ -0,0 +1,59 @@
+"""
+**Introduction to ONNX** ||
+`Export a PyTorch model to ONNX `_
+
+Introduction to ONNX
+====================
+
+Author:
+`Thiago Crepaldi `_
+
+`Open Neural Network eXchange (ONNX) `_ is an open standard
+format for representing machine learning models. The ``torch.onnx`` module provides APIs to
+capture the computation graph from a native PyTorch :class:`torch.nn.Module` model and convert
+it into an `ONNX graph `_.
+
+The exported model can be consumed by any of the many
+`runtimes that support ONNX `_,
+including Microsoft's `ONNX Runtime `_.
+
+.. note::
+    Currently, there are two flavors of ONNX exporter APIs,
+    but this tutorial will focus on ``torch.onnx.dynamo_export``.
+
+The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
+bytecode into an `FX graph `_.
+The resulting FX graph is then polished before it is finally translated into an
+`ONNX graph `_.
+
+The main advantage of this approach is that the `FX graph `_ is captured using
+bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
+
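+In practice, the export is a single call. A minimal sketch (the model and the
+tensor shapes below are illustrative only):
+
+.. code-block:: python
+
+    import torch
+
+    class TwoLayerNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = torch.nn.Linear(4, 8)
+            self.fc2 = torch.nn.Linear(8, 2)
+
+        def forward(self, x):
+            return self.fc2(torch.relu(self.fc1(x)))
+
+    # TorchDynamo captures the forward pass as an FX graph, which is then
+    # translated into an ONNX graph and held in memory until saved.
+    export_output = torch.onnx.dynamo_export(TwoLayerNet(), torch.randn(1, 4))
+    export_output.save("two_layer_net.onnx")
+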
+Dependencies
+------------
+
+The ONNX exporter depends on extra Python packages:
+
+ - `ONNX `_
+ - `ONNX Script `_
+
+They can be installed through `pip `_:
+
+.. code-block:: bash
+
+ pip install --upgrade onnx onnxscript
+
+Further reading
+---------------
+
+The list below refers to tutorials that range from basic examples to advanced scenarios,
+not necessarily in the order they are listed.
+Feel free to jump directly to specific topics of interest, or
+sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+
+.. include:: /beginner_source/onnx/onnx_toc.txt
+
+.. toctree::
+   :hidden:
+
+"""
diff --git a/beginner_source/onnx/onnx_toc.txt b/beginner_source/onnx/onnx_toc.txt
new file mode 100644
index 0000000000..2386430ba7
--- /dev/null
+++ b/beginner_source/onnx/onnx_toc.txt
@@ -0,0 +1 @@
+| 1. `Export a PyTorch model to ONNX `_
\ No newline at end of file
diff --git a/en-wordlist.txt b/en-wordlist.txt
index ee2c79b6b4..4ed4d2077c 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -132,6 +132,7 @@ Lipschitz
logits
Lua
Luong
+macOS
MLP
MLPs
MNIST
@@ -147,11 +148,14 @@ NTK
NUMA
NaN
NanoGPT
+Netron
NeurIPS
NumPy
Numericalization
Numpy's
ONNX
+ONNX's
+ONNX Runtime
OpenAI
OpenMP
Ornstein
@@ -386,6 +390,7 @@ prewritten
primals
profiler
profilers
+protobuf
py
pytorch
quantized
diff --git a/index.rst b/index.rst
index 3070002466..feee988b0e 100644
--- a/index.rst
+++ b/index.rst
@@ -272,6 +272,15 @@ What's new in PyTorch tutorials?
:tags: Text
+.. ONNX
+
+.. customcarditem::
+   :header: (optional) Exporting a PyTorch model to ONNX using the TorchDynamo backend and Running it using ONNX Runtime
+   :card_description: Build an image classifier model in PyTorch and convert it to ONNX before deploying it with ONNX Runtime.
+   :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
+   :link: beginner/onnx/export_simple_model_to_onnx_tutorial.html
+   :tags: Production,ONNX,Backends
+
.. Reinforcement Learning
.. customcarditem::
@@ -329,11 +338,12 @@ What's new in PyTorch tutorials?
:tags: Production,TorchScript
.. customcarditem::
- :header: (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
+   :header: (optional) Exporting a PyTorch Model to ONNX using the TorchScript backend and Running it using ONNX Runtime
:card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
:image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
:link: advanced/super_resolution_with_onnxruntime.html
- :tags: Production
+ :tags: Production,ONNX
+
.. Code Transformations with FX
@@ -902,6 +912,14 @@ Additional Resources
beginner/torchtext_custom_dataset_tutorial
+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Backends
+
+   beginner/onnx/intro_onnx
+
.. toctree::
:maxdepth: 2
:includehidden:
@@ -918,6 +936,7 @@ Additional Resources
:hidden:
:caption: Deploying PyTorch Models in Production
+   beginner/onnx/intro_onnx
intermediate/flask_rest_api_tutorial
beginner/Intro_to_TorchScript_tutorial
advanced/cpp_export
diff --git a/intermediate_source/memory_format_tutorial.py b/intermediate_source/memory_format_tutorial.py
index f08980265d..26bc5c9d53 100644
--- a/intermediate_source/memory_format_tutorial.py
+++ b/intermediate_source/memory_format_tutorial.py
@@ -131,7 +131,7 @@
# produces output in contiguous memory format. Otherwise, output will
# be in channels last memory format.
-if torch.backends.cudnn.version() >= 7603:
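+# ``torch.backends.cudnn.version()`` returns ``None`` when cuDNN is unavailable,
+# so check availability first to avoid comparing ``None`` with an integer.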
+if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
diff --git a/requirements.txt b/requirements.txt
index 84c35e78d0..31b3f0ad16 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,6 +33,9 @@ datasets
transformers
torchmultimodal-nightly # needs to be updated to stable as soon as it's available
deep_phonemizer==0.0.17
+onnx
+onnxscript
+onnxruntime
importlib-metadata==6.8.0