🚀 Feature
Add support for PyTorch Callable -> Thunder Callable translation in Thunder JIT.
Motivation
Several PyTorch operators accept a Python function containing PyTorch operations as one of their arguments. In PyTorch, these are called "higher-order operators". Examples of these operators:

- torch.cond
- torch.utils.checkpoint.checkpoint
- torch.associative_scan
Pitch
Thunder should support all of the above operators. It is easy to support only Thunder functions as inputs (see #1127 for an example with checkpoint), but the best user experience would come from automatically translating user-provided PyTorch callables into Thunder ones while constructing the initial Thunder trace.
An example of torch.cond to support:
```python
import torch

def true_fn(x: torch.Tensor):
    return torch.cos(x)

def false_fn(x: torch.Tensor):
    return torch.sin(x)

# Ideally, adding the thunder.jit decorator should just work. This requires
# translating true_fn and false_fn into Thunder functions so that their
# bodies can be traced and understood by the rest of the system.
# @thunder.jit
def f(true_fn, false_fn, x):
    return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

x = torch.ones(5)
print(f(true_fn, false_fn, x))
```
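To make the goal concrete, here is a pure-Python sketch of the translation idea: rebinding a user callable's `torch` references to a tracing namespace so that its operations are recorded instead of executed eagerly. The names (`TraceOp`, `TraceNamespace`, `translate`) are hypothetical illustrations, not Thunder's actual machinery.

```python
import types

class TraceOp:
    """Records an op call into a trace instead of executing it eagerly."""
    def __init__(self, name, trace):
        self.name, self.trace = name, trace
    def __call__(self, *args):
        self.trace.append((self.name, args))
        return f"%{len(self.trace) - 1}"  # symbolic result name

class TraceNamespace:
    """Stands in for the `torch` module while tracing a user callable."""
    def __init__(self, trace):
        self._trace = trace
    def __getattr__(self, name):
        # any attribute lookup (torch.cos, torch.sin, ...) yields a recorder
        return TraceOp(name, self._trace)

def translate(fn):
    """Rebuild fn so its global `torch` points at the tracing namespace.

    Only handles simple functions (no closures or default arguments);
    a real JIT would interpret the bytecode rather than swap globals.
    """
    trace = []
    ns = dict(fn.__globals__)
    ns["torch"] = TraceNamespace(trace)
    traced_fn = types.FunctionType(fn.__code__, ns, fn.__name__)
    return traced_fn, trace
```

Calling `translate(true_fn)` and then invoking the result records `("cos", ...)` into the trace instead of computing a cosine, which is the kind of capture the initial Thunder trace construction would need.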
Alternatives
@t-vi, please fill in this section with details about alternative solutions.
Implement checkpointing via a lookaside that
- traces the checkpointed function as is,
- sets a flag in the JITCtx that causes the wrap callback to add a rematerialize_for_backward (or similar) proxy tag to proxies that are wrapped (or created and wrapped),
- then clears the flag on the outputs.
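The steps above can be sketched in plain Python. `JITCtx`, `Proxy`, and the tag name are illustrative stand-ins for Thunder internals, not the real API:

```python
from contextlib import contextmanager

class JITCtx:
    def __init__(self):
        self.rematerialize = False  # the flag the lookaside toggles

class Proxy:
    """Stand-in proxy: the wrap callback consults the flag on creation."""
    def __init__(self, name, ctx):
        self.name = name
        self.tags = set()
        if ctx.rematerialize:
            # proxies wrapped (or created and wrapped) while the flag is
            # set receive the rematerialization tag
            self.tags.add("rematerialize_for_backward")

@contextmanager
def checkpointed(ctx):
    """Set the flag while tracing the checkpointed function, then clear it."""
    ctx.rematerialize = True
    try:
        yield
    finally:
        ctx.rematerialize = False
```

Proxies created inside the `checkpointed` block carry the tag; proxies created outside do not, which is what lets a later transform decide what to rematerialize for backward.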
The other higher-order functions are currently prototypes; barring other pressing needs, I think this should inform our prioritization. It would be a formidable change to the nature of traces to have higher-order functions in them.
Before looking at this use case, it would be good to figure out "call jitted function / module from jitted function" first; this would likely be very useful for jitting training loops with optimizer steps.
Additional context
An attempt at using jit inside lookasides currently fails: #1126.
Given that all of these functions except torch.utils.checkpoint.checkpoint are marked as
.. warning::
`torch.associative_scan` is a prototype feature in PyTorch. It currently
does not support autograd and you may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
we can probably move higher-order functions to a longer-term discussion as support for them matures in PyTorch.
torch.utils.checkpoint.checkpoint looks to me more like a case of "wrapping" rather than one of a general higher-order function.
Specifically, I'd like to understand the implications of having higher-order functions in traces from the point of view of transform authors, and to ensure everything keeps working together.
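For illustration, the "wrapping" view of checkpointing mentioned above can be sketched in plain Python: the wrapper calls the user function directly and adds recomputation bookkeeping around it, rather than embedding the function as a sub-graph operand the way torch.cond does. `checkpoint_like` below is a toy stand-in, not torch.utils.checkpoint.checkpoint:

```python
def checkpoint_like(fn, *args):
    """Toy checkpoint: run fn now, keep only its inputs, and return a
    closure that can recompute fn later (as a backward pass would)."""
    saved_inputs = args        # intermediates are "dropped"; inputs kept
    out = fn(*args)            # fn is simply called, not traced as a sub-graph
    def recompute():
        return fn(*saved_inputs)
    return out, recompute
```

Because the user function is just called, a JIT that can trace through an ordinary call already handles this shape; no higher-order construct needs to appear in the trace itself.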