Torch export tutorial #2557
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2557
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure as of commit 21127d7 with merge base e28eace. This comment was automatically generated by Dr. CI and updates every 15 minutes.
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# to be run on high performance environments.
It's just a way of getting a full graph out of PT2, no more or less. There's nothing high-performance-environment opinionated about it. I would make this an example of the former point.
It's more like, the intention of export is to get a full graph and take it to a different environment. Otherwise people would just use torch.compile
# ``torch.export`` is built using the components of ``torch.compile``,
# so it may be helpful to familiarize yourself with ``torch.compile``.
# For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
Agreed, but unclear if users of an API should care about implementation details
+1 I'm unsure if you need to link torch.compile's tutorial specifically, since you don't really need to know about torch.compile to use torch.export. But, if you did want to mention the similarities between the two, you could link to https://pytorch.org/docs/main/export.html#existing-frameworks.
+1 to remove the tutorial link
######################################################################
# ``torch.export`` returns an ``ExportedProgram``, which is not a ``torch.nn.Module``,
# but can still be ran as a function:
nit: ran -> run
######################################################################
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,
link to fx docs?
print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
# There are some limitations one should be aware of:
Oh this suffices :)
Thanks for doing this! Just a few suggestions
# with the export use case.
#
# A graph break is necessary in cases such as:
# - data-dependent control flow
Are there other cases for graph breaks? Like unsupported python operations? Otherwise this doesn't need to be a list.
EDIT: looks like other examples are below, perhaps the formatting should be improved for this first section?
exported_map_example = export(map_example, (torch.randn(4, 3),))
inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
print(exported_map_example(inp))
In this example, it's a bit unclear what `map` is doing; can we get the example output to understand better? Perhaps we can also add a comment saying what shape `x` is.
def false_fn(x):
    return x - const
return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
return control_flow.map(map_fn, xs, torch.tensor([2.0]))
Is the `map` signature like `functools.partial` in that it can bind additional arguments? Is this because something like `lambda x: map_fn(x, torch.tensor([2.0]))` doesn't work? Do the additional arguments all have to be tensors?
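From experimenting, `map` does appear to forward extra positional arguments after `xs` to the mapped function, loosely like `functools.partial` binding. A hedged sketch, run eagerly; `map` is experimental and this import path may change between releases:

```python
import torch
from functorch.experimental.control_flow import map

def map_fn(x, y):
    # x is one slice of xs along dim 0 (shape (3,)); y is the extra
    # argument bound at the map call site, passed to every invocation.
    return x + y

xs = torch.ones(4, 3)
# Arguments after xs are forwarded to map_fn on each slice.
out = map(map_fn, xs, torch.tensor([2.0]))
print(out.shape)
```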
return torch.relu(x)

constraints1 = [
    dynamic_dim(inp1, 0),
What bound does this specify? Is this just saying that dimension 0 exists?
def constraints_example4(x, y):
    b = y.item()
    constrain_as_value(b, 3, 5)
What happens if the constraint is not true for some input? Is it an exception at export time, or would you get it later?
If the inputs to export violate the constraints, the error happens at export time (example will be included). If future inputs violate constraints, the error happens when you try to run the exported graph on the future inputs.
# - Define a ``Meta`` implementation of the custom op that returns an empty
#   tensor with the same shape as the expected output

@impl(m, "custom_op", "Meta")
Side note on naming: since the company is called "Meta", I feel this name might be misinterpreted. Perhaps an all-lowercase "meta" would be a little bit better?
I believe using the exact key "Meta" is required in order to get this working - this is a dispatch-related thing.
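To illustrate the dispatch point, a self-contained sketch of defining a custom op with both a real kernel and a "Meta" kernel; the namespace `mylib` and the op name `custom_op` are made-up names for this example:

```python
import torch
from torch.library import Library, impl

# A scratch library namespace; "mylib" and "custom_op" are illustrative names.
m = Library("mylib", "DEF")
m.define("custom_op(Tensor x) -> Tensor")

@impl(m, "custom_op", "CPU")
def custom_op_cpu(x):
    # The real computation, used for ordinary CPU tensors.
    return x.relu()

# The "Meta" dispatch key (capitalized exactly like this) lets tracing
# compute output shapes/dtypes without running the real kernel.
@impl(m, "custom_op", "Meta")
def custom_op_meta(x):
    return torch.empty_like(x)

# A real call hits the CPU kernel; a meta call only propagates shape.
print(torch.ops.mylib.custom_op(torch.randn(3)))
print(torch.ops.mylib.custom_op(torch.empty(3, device="meta")).shape)
```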
#
# ``torch.export`` takes in a callable (including ``torch.nn.Module``s),
# a tuple of positional arguments, and optionally (not shown in the example below),
# a dictionary of keyword arguments.
also a list of constraints?
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# to be run on high performance environments.
It's more like, the intention of export is to get a full graph and take it to a different environment. Otherwise people would just use torch.compile
# ``torch.export`` is built using the components of ``torch.compile``,
# so it may be helpful to familiarize yourself with ``torch.compile``.
# For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
+1 I'm unsure if you need to link torch.compile's tutorial specifically, since you don't really need to know about torch.compile to use torch.export. But, if you did want to mention the similarities between the two, you could link to https://pytorch.org/docs/main/export.html#existing-frameworks.
# We can also use ``map``, which applies a function across the first dimension
# of the first tensor argument.

from functorch.experimental.control_flow import map
I don't think we want map in the tutorial because it won't be fully released by 2.1 or PTC.
tb.print_exc()

######################################################################
# Control Flow Ops
Do we want this section on cond in the tutorial? cc @gmagogsfm @ydwu4
The tutorial will be released alongside 2.1. cond isn't mentioned in the export docs, but it's technically in ExportDB.
# Constraints
# -----------
# Ops can have different specializations for different tensor shapes, so
# ``ExportedProgram``s uses constraints on tensor shapes in order to ensure
# ``ExportedProgram``s uses constraints on tensor shapes in order to ensure
# ``export`` specializes on input tensors' shapes used for tracing, requiring users to provide the same shaped tensors to the ExportedProgram.
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
# express shape constraints manually.
#
# We can use ``dynamic_dim`` to remove a dimension's constraints, or to
# manually provide an upper or lower bound. In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval (1, 18].
not exactly -- we use dynamic_dim to specify which dimensions are dynamic, and can provide shape constraints on top of these dynamic_dims. Just doing `dynamic_dim(inp1, 0)` tells us that dim 0 of inp1 is dynamic, but doing `3 < dynamic_dim(inp1, 1), dynamic_dim(inp1, 1) <= 18` tells us that the dimension is bounded between (3, 18].
Looks great! Thanks for doing this! Just some small comments
# ``torch.export`` is built using the components of ``torch.compile``,
# so it may be helpful to familiarize yourself with ``torch.compile``.
# For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
+1 to remove the tutorial link
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to
# By default, ``torch.export`` requires all future input tensors to have the same shape
# as their respective example input, but we can modify the ``torch.export`` call to
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
Maybe we should point out that the mutation is represented via an additional output of updated values.
######################################################################
# Other attributes of interest in ``ExportedProgram`` include:
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
Can we expand a bit more on graph_signature? (maybe even pytree)? so that users can get an intuitive understanding about how inputs/outputs are organized.
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
#   operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
# operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
# operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.)
######################################################################
# - calling unsupported functions (e.g. many builtins)
# - calling unsupported functions (e.g. many builtins)
# - calling unsupported functions, such as many ``builtins``
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
#   operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
# - Branch functions cannot mutate inputs or globals
# - Branch functions cannot mutate inputs or globals
# - Branch functions cannot mutate ``inputs`` or ``globals``
"""

######################################################################
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# :func:`torch.export` enables you to export PyTorch models intended
I would keep "the PyTorch 2.0 way" since it gives a brief reason as to why export is a significant, new way to export a model.
######################################################################
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# to be run on high performance environments.
# to be run on high performance environments.
# to be run on high performance environments into static and standardized model representation.
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
# Let's review some of the important ``ExportedProgram`` attributes.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
#
# The FX graph has some important properties:
#
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
# The printed code shows that FX graph only contains ATen-level ops (i.e. ``torch.ops.aten``)
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
# The printed code shows that FX graph only contains ATen-level ops (i.e. ``torch.ops.aten``)
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
# The printed code shows that FX graph only contains ATen-level ops, such as ``torch.ops.aten``,
# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
# A graph break is necessary in cases such as:
# A graph break is necessary in cases such as:
# A graph break is necessary in the following cases:
#
"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
nit: "PyTorch 2," not "PyTorch 2.0"
@gchanan we are using "PyTorch 2.0" in a lot of places as the title for posters etc.: https://events.linuxfoundation.org/pytorch-conference/program/schedule/ - is that OK?
Could we say 2.x so that it stays relevant to all 2.x versions?
######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
# static and standardized model representations, intended
what does "static" mean here?
"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
also, you should say "trace-based" somewhere here -- it's really critical to understand that and right now it's dropped into the middle of a sentence about the FX graph format.
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
# A graph break is necessary in the following cases:
I don't think we should try to enumerate the cases where graph breaks are necessary here. A more straightforward description would be something like:
- Therefore, a lot of the work in making a PyTorch program "exportable" is around removing graph breaks.
- An example or two
- For more information and examples of removing graph breaks, see ExportDB.
# The constraints API is a prototype feature in PyTorch, included as a part of the torch.export release.
# Backwards compatibility is not guaranteed. We anticipate releasing a more stable constraints API in the future.
#
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
An example of where this could go wrong would be useful. I think you are just demonstrating that the exported program will check for you (is that right?), but a motivation for why this is necessary would be useful.
######################################################################
# We can actually use ``torch.export`` to guide us as to which constraints
# are necessary. We can do this by relaxing all constraints and letting ``torch.export``
# error out.
It's a little strange to "relax all constraints" via inserting things in the constraints region; this probably deserves an explanation (that by default there is an implicit constraint that all dimensions are the traced static value).
######################################################################
# ExportDB
# --------
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
as noted above, I think this is a really important point and should be earlier in the tutorial.
@@ -7,25 +7,42 @@
"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
# static and standardized model representations, intended
# :func:`torch.export` is the PyTorch 2 way to export PyTorch models into
@williamwen42 Can we add a clear enough warning at the top of the tutorial that says torch.export is in prototype status, this tutorial only applies to 2.1 snapshot and there will be BC breaking changes in subsequent versions?
Some of the planned BC breaking changes:
https://github.com/pytorch/pytorch/pull/108448/files/e2f4c3d6ed1e78665a0ad56022099d025f2010a2..3f1fb6710841d57d142c33d99fe183d0ecd15434
almost there!!
######################################################################
# ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
# but it can still be run as a function:
# but it can still be run as a function:
# but it is callable and can be run in the same way as the input callable.
# example inputs given to the initial ``torch.export`` call.
# If we try to run the first ``ExportedProgram`` example with a tensor
# with a different shape, we get an error:
Maybe it would be good to copy this example back down here, since users might've forgotten what `exported_mod` looked like?
    dynamic_dim(inp1, 1) <= 18,
]

exported_constraints_example1 = export(constraints_example1, (inp1,), constraints=constraints1)
Might be helpful to also print(exported_constraints_example1.range_constraints) to show that the ranges for the symbolic shapes are there
I cover printing out `range_constraints` and `equality_constraints` a little bit below.
# of ``graph``:

print(exported_mod)
exported_mod.graph_module.print_readable()
Just figured out how to view the actual thing: https://docs-preview.pytorch.org/pytorch/tutorials/2557/intermediate/torch_export_tutorial.html
It looks like you need to do print(exported_mod.graph_module.print_readable())
print(exported_mod.range_constraints)
print(exported_mod.equality_constraints)
I don't think we should print these here, because they're empty. It's more useful if we print it with the dynamic_dims tracing.
    export(constraints_example3, (inp4, inp5), constraints=constraints3)
except Exception:
    tb.print_exc()
It might be good to separate these two blocks -- first block prints the stacktrace, and then second block shows that now with the specify_constraints() function, it works.
Looks good! Thanks!!
- Add Torch export tutorial (#2557)
- Temporarily pull 2.1 binaries in the stable branch
- Update wordlist
---------
Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: William Wen <[email protected]>
Ready for review