Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use transform_for_execution to get callable for torch compile #1041

Merged
merged 4 commits into from
Aug 26, 2024

Conversation

t-vi
Copy link
Collaborator

@t-vi t-vi commented Aug 24, 2024

Fixes #1040

Note that back in the day, @IvanYashchuk added a comment about suspected bugs preventing this approach:

# Here instead of using thunder.trace we could use torch_trace =
# passes._transform_for_operator_executor_execution(region_trace, [torchex])
# but then we would need to handle unpacking of the args explicitly For
# example with:
# try:
# token = set_tracectx(region_trace)
# col = CollectionProxy(region_trace.args, name="args")
# _ = prims.unpack_sequence(col, len(region_trace.args))
# finally:
# reset_tracectx(token)
# region_trace.bound_symbols.extend(bsyms)
# But there are some issues with the
# _transform_for_operator_executor_execution implementation that need to be
# fixed first. One issue is that it doesn't maintain the ssa form of the
# trace, which is needed for all the passes to work correctly.
# TODO: issue "Try using _transform_for_operator_executor_execution for
# torch.compile executor"

Now I seem to get good results with just calling transform_for_execution with only the torchex as int the executors_list. Maybe the bugs have been fixed since.

Copy link
Collaborator

@lantiga lantiga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamped!

@t-vi
Copy link
Collaborator Author

t-vi commented Aug 24, 2024

Seems I hit the exact problem that Ivan described.

@t-vi
Copy link
Collaborator Author

t-vi commented Aug 24, 2024

Turns out that here, it seems the problem was that the region_trace was not enough of a proper trace to work.
Note that the region trace still does not have unpack_trivial at the top.
Maybe we should either forego regions entirely in favour of traces or have a utility region->trace that does it.

@t-vi
Copy link
Collaborator Author

t-vi commented Aug 26, 2024

Merging, but don't hesitate to scream if you don't like it. 😉

@t-vi t-vi merged commit d95ca14 into main Aug 26, 2024
37 checks passed
@t-vi t-vi deleted the tom/torch_compile_ex_callable branch August 26, 2024 09:29
@IvanYashchuk
Copy link
Collaborator

Maybe we should either forego regions entirely in favour of traces

Yes, that would be nice. Regions look like subclasses of TraceCtx with positional-only arguments and automatic identification of arguments and outputs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

KeyError: 'type' from torch.compile executor
3 participants