
Torch export tutorial #2557

Merged: 12 commits into 2.1-RC-TEST on Sep 22, 2023

Conversation

@williamwen42 (Member) commented Sep 12, 2023

Ready for review

@pytorch-bot (bot) commented Sep 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2557

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 21127d7 with merge base e28eace:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@svekars svekars changed the base branch from main to 2.1-RC-TEST September 12, 2023 15:08
@svekars svekars added the 2.1 label Sep 12, 2023
Comment on lines 10 to 11
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# to be run on high performance environments.


It's just a way of getting a full graph out of PT2, no more or less. There's nothing high-performance-environment opinionated about it. I would make this an example of a point on the former.

Contributor

It's more like, the intention of export is to get a full graph and take it to a different environment. Otherwise people would just use torch.compile

Comment on lines 13 to 15
# ``torch.export`` is built using the components of ``torch.compile``,
# so it may be helpful to familiarize yourself with ``torch.compile``.
# For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.


Agreed, but unclear if users of an API should care about implementation details

Contributor

+1 I'm unsure if you need to link torch.compile's tutorial specifically, since you don't really need to know about torch.compile to use torch.export. But, if you did want to mention the similarities between the two, you could link to https://pytorch.org/docs/main/export.html#existing-frameworks.


+1 to remove the tutorial link


######################################################################
# ``torch.export`` returns an ``ExportedProgram``, which is not a ``torch.nn.Module``,
# but can still be ran as a function:


nit: ran -> run


######################################################################
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,


link to fx docs?

print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
# There are some limitations one should be aware of:


Oh this suffices :)

@dulinriley left a comment

Thanks for doing this! Just a few suggestions

# with the export use case.
#
# A graph break is necessary in cases such as:
# - data-dependent control flow


Are there other cases for graph breaks? Like unsupported python operations? Otherwise this doesn't need to be a list.

EDIT: looks like other examples are below, perhaps the formatting should be improved for this first section?


exported_map_example = export(map_example, (torch.randn(4, 3),))
inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
print(exported_map_example(inp))


In this example, it's a bit unclear what map is doing, can we get the example output to understand better?
Perhaps also we can add a comment saying what shape x is

    def false_fn(x):
        return x - const
    return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
return control_flow.map(map_fn, xs, torch.tensor([2.0]))


Is the map signature like functools.partial in that it can bind additional arguments?
Is this because something like lambda x: map_fn(x, torch.tensor([2.0])) doesn't work?
Do the additional arguments all have to be tensors?

return torch.relu(x)

constraints1 = [
dynamic_dim(inp1, 0),


What bound does this specify? Is this just saying that dimension 0 exists?


def constraints_example4(x, y):
    b = y.item()
    constrain_as_value(b, 3, 5)


What happens if the constraint is not true for some input? Is it an exception at export time, or would you get it later?

Member Author

If the inputs to export violate the constraints, the error happens at export time (example will be included). If future inputs violate constraints, the error happens when you try to run the exported graph on the future inputs.
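To illustrate the export-time vs. run-time distinction described in this reply, here is a plain-Python sketch (hypothetical helper names, not the torch.export API): the constraint is checked once against the example input when "exporting", and again on every future input when the exported artifact runs.

```python
def export_with_constraint(fn, example, lo, hi):
    """Toy analogy: validate the example input now, and re-validate
    every future input when the 'exported' callable is run."""
    if not (lo <= example <= hi):
        raise ValueError("example input violates constraint at export time")

    def exported(x):
        # mirrors a constraint baked into the exported graph
        if not (lo <= x <= hi):
            raise ValueError("input violates constraint at run time")
        return fn(x)

    return exported

prog = export_with_constraint(lambda b: b * 2, example=4, lo=3, hi=5)
print(prog(5))  # within range, runs fine: 10
try:
    prog(7)     # error surfaces only when the exported graph is run
except ValueError as e:
    print(e)
```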

Comment on lines 398 to 401
# - Define a ``Meta`` implementation of the custom op that returns an empty
# tensor with the same shape as the expected output

@impl(m, "custom_op", "Meta")


Side note of naming: since the company is called "Meta", I feel this name might be misinterpreted.
Perhaps an all lowercase "meta" would be a little bit better?

Member Author

I believe using the exact key "Meta" is required in order to get this working - this is a dispatch-related thing.

#
# ``torch.export`` takes in a callable (including ``torch.nn.Module``s),
# a tuple of positional arguments, and optionally (not shown in the example below),
# a dictionary of keyword arguments.
Contributor

also a list of constraints?



# We can also use ``map``, which applies a function across the first dimension
# of the first tensor argument.

from functorch.experimental.control_flow import map
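As a rough illustration of the semantics the quoted comment describes, here is a plain-Python analogy (a hypothetical `map_like` helper over lists, not the real `control_flow.map`): the function is applied to each slice along the first dimension of the first argument, and any extra arguments are passed through unchanged.

```python
def map_like(fn, xs, *extra):
    """Apply fn to each slice along the first 'dimension' of xs,
    passing any extra arguments through, and collect the results."""
    return [fn(x, *extra) for x in xs]

# Each row of xs is handed to fn together with the extra argument (10 here).
result = map_like(lambda row, c: [v + c for v in row], [[1, 2], [3, 4]], 10)
print(result)  # [[11, 12], [13, 14]]
```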
Contributor

I don't think we want map in the tutorial because it won't be fully released by 2.1 or PTC.

tb.print_exc()

######################################################################
# Control Flow Ops
Contributor

Do we want this section on cond in the tutorial? cc @gmagogsfm @ydwu4
The tutorial will be released alongside 2.1. cond isn't mentioned in the export docs, but it's technically in ExportDB.

# Constraints
# -----------
# Ops can have different specializations for different tensor shapes, so
# ``ExportedProgram``s uses constraints on tensor shapes in order to ensure
Contributor

Suggested change
# ``ExportedProgram``s uses constraints on tensor shapes in order to ensure
# ``export`` specializes on input tensors' shapes used for tracing, requiring users to provide the same shaped tensors to the ExportedProgram.

Comment on lines 215 to 223
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
# express shape constraints manually.
#
# We can use ``dynamic_dim`` to remove a dimension's constraints, or to
# manually provide an upper or lower bound. In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval (1, 18].
Contributor

not exactly -- we use dynamic_dim to specify which dimensions are dynamic, and can provide shape constraints on top of these dynamic_dims.

Just doing dynamic_dim(inp1, 0) tells us that dim 0 of inp1 is dynamic, but doing 3 < dynamic_dim(inp1, 1), dynamic_dim(inp1, 1) <= 18 tells us that the dimension is bounded between (3, 18].
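The shape semantics in this comment can be sketched in plain Python (a toy validator with a hypothetical `check_shape` name, not the torch.export API): dim 0 is dynamic with no bound at all, while dim 1 is dynamic but must lie in (3, 18].

```python
def check_shape(shape):
    """Toy analogy of the runtime shape checks an exported program performs:
    dim 0 is unconstrained dynamic; dim 1 must satisfy 3 < size <= 18."""
    if not (3 < shape[1] <= 18):
        raise ValueError(f"dim 1 must be in (3, 18], got {shape[1]}")
    return True

check_shape((100, 10))   # fine: dim 0 can be anything, dim 1 is in range
try:
    check_shape((2, 3))  # dim 1 == 3 violates the strict lower bound
except ValueError as e:
    print(e)
```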

@gmagogsfm left a comment

Looks great! Thanks for doing this! Just some small comments

Comment on lines 13 to 15
# ``torch.export`` is built using the components of ``torch.compile``,
# so it may be helpful to familiarize yourself with ``torch.compile``.
# For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.


+1 to remove the tutorial link

Comment on lines 215 to 216
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to


Suggested change
# By default, ``torch.export`` requires all tensors to have the same shape
# as the example inputs, but we can modify the ``torch.export`` call to
# By default, ``torch.export`` requires all future input tensors to have the same shape
# as their respective example input, but we can modify the ``torch.export`` call to

Comment on lines 77 to 78
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).


Maybe we should point out that the mutation is represented via additional output of updated values
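The point in this comment can be sketched with a plain-Python analogy (illustrative only, not ATen functionalization): an in-place op mutates its argument, while the functionalized form leaves the input untouched and surfaces the updated value as an output instead.

```python
def relu_inplace(xs):
    # mutating version: overwrites its argument in place
    for i, v in enumerate(xs):
        xs[i] = max(v, 0.0)

def relu_functional(xs):
    # functionalized version: input is untouched, the updated
    # value is returned as an extra output
    return [max(v, 0.0) for v in xs]

data = [-1.0, 2.0]
out = relu_functional(data)
print(data)  # unchanged: [-1.0, 2.0]
print(out)   # [0.0, 2.0]
```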


######################################################################
# Other attributes of interest in ``ExportedProgram`` include:
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.


Can we expand a bit more on graph_signature? (maybe even pytree)? so that users can get an intuitive understanding about how inputs/outputs are organized.

# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
# operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
Contributor

Suggested change
# operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
# operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.)
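The rules quoted above can be summarized with a plain-Python analogy (a hypothetical `cond_like` helper, not the real `control_flow.cond`): one boolean predicate, two branch functions with matching signatures over the operands, and a single result.

```python
def cond_like(pred, true_fn, false_fn, operands):
    """Toy analogy of cond: pick a branch by the predicate and
    apply it to the operands; both branches share one signature."""
    branch = true_fn if pred else false_fn
    return branch(*operands)

# predicate is False here (sum is -1), so false_fn runs: 5 - 1 = 4
print(cond_like(sum([1, -2]) > 0, lambda x: x + 1, lambda x: x - 1, [5]))
```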



######################################################################
# - calling unsupported functions (e.g. many builtins)
Contributor

Suggested change
# - calling unsupported functions (e.g. many builtins)
# - calling unsupported functions, such as many ``builtins``

# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
# operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
# - Branch functions cannot mutate inputs or globals
Contributor

Suggested change
# - Branch functions cannot mutate inputs or globals
# - Branch functions cannot mutate ``inputs`` or ``globals``

"""

######################################################################
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
Contributor

Suggested change
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# :func:`torch.export` enables you to export PyTorch models intended

Member Author

I would keep "the PyTorch 2.0 way" since it gives a brief reason as to why export is a significant, new way to export a model.


######################################################################
# ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
# to be run on high performance environments.
Contributor

Suggested change
# to be run on high performance environments.
# to be run on high performance environments into static and standardized model representation.

Comment on lines 60 to 65
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
Contributor

Suggested change
# ``ExportedProgram`` has some attributes that are of interest.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
# Let's review some of the important ``ExportedProgram`` attributes that.
# The ``graph`` attribute is an FX graph traced from the function we exported,
# that is, the computation graph of all PyTorch operations.
#
# The FX graph has some important properties:
#
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.

Comment on lines 76 to 78
# The printed code shows that FX graph only contains ATen-level ops (i.e. ``torch.ops.aten``)
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
Contributor

Suggested change
# The printed code shows that FX graph only contains ATen-level ops (i.e. ``torch.ops.aten``)
# and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
# The printed code shows that FX graph only contains ATen-level ops, such as ``torch.ops.aten``,
# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.

# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
# A graph break is necessary in cases such as:
Contributor

Suggested change
# A graph break is necessary in cases such as:
# A graph break is necessary in the following cases:
#

"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
Contributor

nit: "PyTorch 2," not "PyTorch 2.0"

Contributor

@gchanan we are using PyTorch 2.0 in a lot of places as title for poster etc.: https://events.linuxfoundation.org/pytorch-conference/program/schedule/ is that OK?

Contributor

Could we say 2.x so that it stays relevant to all 2.x versions?


######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
# static and standardized model representations, intended
Contributor

what does "static" mean here?

"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
Contributor

also, you should say "trace-based" somewhere here -- it's really critical to understand that and right now it's dropped into the middle of a sentence about the FX graph format.

# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
# A graph break is necessary in the following cases:
Contributor

I don't think we should try to enumerate the cases where graph breaks are necessary here. A more straightforward description would be something like:

  • Therefore, a lot of the work in making a PyTorch program "exportable" is around removing graph breaks.
  • An example or two
  • For more information and examples of removing graph breaks, see ExportDB.

# The constraints API is a prototype feature in PyTorch, included as a part of the torch.export release.
# Backwards compatibility is not guaranteed. We anticipate releasing a more stable constraints API in the future.
#
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
Contributor

An example of where this could go wrong would be useful. I think you are just demonstrating that the exported program will check for you (is that right?), but a motivation for why this is necessary would be useful.

######################################################################
# We can actually use ``torch.export`` to guide us as to which constraints
# are necessary. We can do this by relaxing all constraints and letting ``torch.export``
# error out.
Contributor

It's a little strange to "relax all constraints" via inserting things in the constraints region; this probably deserves explanation (that by default there is an implicit constraint that all dimensions are the traced static value).

######################################################################
# ExportDB
# --------
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
Contributor

as noted above, I think this is a really important point and should be earlier in the tutorial.

@@ -7,25 +7,42 @@
"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
# static and standardized model representations, intended
# :func:`torch.export` is the PyTorch 2 way to export PyTorch models into


@williamwen42 Can we add a clear enough warning at the top of the tutorial that says torch.export is in prototype status, this tutorial only applies to 2.1 snapshot and there will be BC breaking changes in subsequent versions?

Some of the planned BC breaking changes:
https://github.com/pytorch/pytorch/pull/108448/files/e2f4c3d6ed1e78665a0ad56022099d025f2010a2..3f1fb6710841d57d142c33d99fe183d0ecd15434

@angelayi (Contributor) left a comment

almost there!!


######################################################################
# ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
# but it can still be run as a function:
Contributor

Suggested change
# but it can still be run as a function:
# but it is callable and can be run in the same way as the input callable.

# example inputs given to the initial ``torch.export`` call.
# If we try to run the first ``ExportedProgram`` example with a tensor
# with a different shape, we get an error:

Contributor

Maybe it would be good to copy this example back down here, since users might've forgotten what exported_mod looked like?

dynamic_dim(inp1, 1) <= 18,
]

exported_constraints_example1 = export(constraints_example1, (inp1,), constraints=constraints1)
Contributor

Might be helpful to also print(exported_constraints_example1.range_constraints) to show that the ranges for the symbolic shapes are there

Member Author

I cover printing out range_constraints and equality_constraints a little bit below.

# of ``graph``:

print(exported_mod)
exported_mod.graph_module.print_readable()
Contributor

Just figured out how to view the actual thing: https://docs-preview.pytorch.org/pytorch/tutorials/2557/intermediate/torch_export_tutorial.html

It looks like you need to do print(exported_mod.graph_module.print_readable())

Comment on lines 101 to 102
print(exported_mod.range_constraints)
print(exported_mod.equality_constraints)
Contributor

I don't think we should print these here, because they're empty. It's more useful if we print it with the dynamic_dims tracing.

try:
    export(constraints_example3, (inp4, inp5), constraints=constraints3)
except Exception:
    tb.print_exc()

Contributor

It might be good to separate these two blocks -- first block prints the stacktrace, and then second block shows that now with the specify_constraints() function, it works.

@angelayi (Contributor) left a comment

Looks good! Thanks!!

@williamwen42 williamwen42 merged commit c8c76fb into 2.1-RC-TEST Sep 22, 2023
17 of 18 checks passed
svekars pushed a commit that referenced this pull request Oct 2, 2023
- Add Torch export tutorial (#2557)
- Temporarily pull 2.1 binaries in the stable branch
- Update wordlist
---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: William Wen <[email protected]>
10 participants