Reentrant JIT for higher order operators #1134

Open
IvanYashchuk opened this issue Sep 10, 2024 · 2 comments
Labels: design, enhancement, interpreter, jit

Comments

@IvanYashchuk (Collaborator) commented Sep 10, 2024

🚀 Feature

Add support for PyTorch Callable -> Thunder Callable translation in Thunder JIT.

Motivation

Several PyTorch operators accept a Python function containing PyTorch operations as one of their arguments. In PyTorch, these operators are called "higher order operators". Examples of these operators:

  • torch.cond
  • torch.utils.checkpoint.checkpoint
  • torch.associative_scan

Pitch

Thunder should support all of the above operators. It's easy to support only Thunder functions as inputs (see the checkpoint example in #1127), but the best user experience would come from automatically translating user-provided PyTorch callables into Thunder ones while constructing the initial Thunder trace.

An example of torch.cond to support:

import torch

def true_fn(x: torch.Tensor):
    return torch.cos(x)
def false_fn(x: torch.Tensor):
    return torch.sin(x)

# Ideally, applying the thunder.jit decorator should just work. This requires
# translating true_fn and false_fn into Thunder functions so that their insides
# can be traced and understood by the rest of the system.
# @thunder.jit
def f(true_fn, false_fn, x):
    return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

x = torch.ones(5)
print(f(true_fn, false_fn, x))
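
For illustration, a minimal sketch of one possible translation strategy. cond_lookaside is a hypothetical helper, not part of Thunder's API, and it only covers the easy case where the predicate is a plain Python bool known at trace time (as above, where x.shape[0] is static):

import torch
import thunder

def true_fn(x: torch.Tensor):
    return torch.cos(x)

def false_fn(x: torch.Tensor):
    return torch.sin(x)

# Hypothetical lookaside (not Thunder's actual API): when the predicate is a
# Python bool available at trace time, the translation can simply pick the
# live branch and jit only it, avoiding tracing both branches.
def cond_lookaside(pred, tfn, ffn, operands):
    branch = tfn if pred else ffn
    return thunder.jit(branch)(*operands)

x = torch.ones(5)
print(cond_lookaside(x.shape[0] > 4, true_fn, false_fn, (x,)))

The general case, where the predicate is a traced tensor, would instead require translating both branches into Thunder subtraces, which is exactly the design question this issue raises.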

Alternatives

@t-vi, please fill in this section with details about alternative solutions.

Implement checkpointing via a lookaside that

  • traces the checkpointed function as-is,
  • sets a flag in the JITCtx that causes the wrap callback to add a rematerialize_for_backward (or similar) proxy tag to proxies that are wrapped (or created and wrapped),
  • then clears the flag on the outputs (a toy sketch of this flow follows the list).
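
To make the control flow concrete, here is a self-contained toy sketch of that flag mechanism. JITCtx, Proxy, wrap, and the rematerialize_for_backward tag mirror the wording above but are illustrative stand-ins, not Thunder's actual internals.

# Toy sketch of the lookaside flow described above; all names are stand-ins.

class JITCtx:
    def __init__(self):
        self.tag_for_remat = False

class Proxy:
    def __init__(self, name):
        self.name = name
        self.tags = set()

def wrap(ctx, name):
    # The wrap callback consults the context flag and tags new proxies.
    p = Proxy(name)
    if ctx.tag_for_remat:
        p.tags.add("rematerialize_for_backward")
    return p

def checkpoint_lookaside(ctx, fn, *args):
    ctx.tag_for_remat = True       # set the flag...
    try:
        out = fn(ctx, *args)       # ...trace the checkpointed function as-is...
    finally:
        ctx.tag_for_remat = False  # ...then clear the flag on the outputs
    return out

ctx = JITCtx()
x = wrap(ctx, "x")                      # created outside checkpoint: no tag

def inner(ctx, x):
    return wrap(ctx, f"cos({x.name})")  # created inside checkpoint: tagged

y = checkpoint_lookaside(ctx, inner, x)
print(x.tags, y.tags)  # set() {'rematerialize_for_backward'}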

The other higher order functions are currently prototypes; barring other pressing needs, I think this should inform our prioritization. It would be a formidable change to the nature of traces to have higher order functions in them.
Before looking at this use case, it would be good to figure out "call a jitted function / module from a jitted function" first; I'd guess this would be very useful for jitting training loops with optimizer steps. A sketch of that usage pattern follows.
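
Mirroring the commented-decorator convention of the example above, the pattern in question would look roughly like this; whether the outer jit can trace through the inner jitted call is exactly what #1126 probes:

import torch
import thunder

model = torch.nn.Linear(4, 4)
jitted_model = thunder.jit(model)  # thunder.jit on a callable/module is public API

# The open question: can an outer jitted function trace through the inner
# jitted call? In a full training loop this would also cover optimizer steps.
# @thunder.jit
def train_step(x):
    return jitted_model(x).sum()

print(train_step(torch.ones(2, 4)))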

Additional context

An attempt at using jit inside lookasides currently fails: #1126.

IvanYashchuk added the design label on Sep 10, 2024
@lantiga (Collaborator) commented Sep 10, 2024

Given that all of these functions except torch.utils.checkpoint.checkpoint are marked as

    .. warning::
        `torch.associative_scan` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

we can probably move higher-order functions to a longer-term discussion as support for them matures in PyTorch.

torch.utils.checkpoint.checkpoint looks to me more like a case of "wrapping" than like a general higher-order function.
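
One way to see the distinction, as a plain-PyTorch sketch: checkpoint invokes its argument immediately (only the backward-time recompute semantics change), so a tracer can inline it, whereas a general higher-order op like torch.cond must keep its callables as subgraphs.

import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return torch.cos(x).sin()

x = torch.ones(3, requires_grad=True)

# "Wrapping": checkpoint calls fn(x) right here, so a tracer can inline fn
# and merely tag its intermediates for recomputation in the backward pass.
y = checkpoint(fn, x, use_reentrant=False)
y.sum().backward()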

@lantiga (Collaborator) commented Sep 10, 2024

Specifically, I'd like to understand the implications of having higher-order functions in traces from the point of view of transform authors, and how we ensure everything keeps working with everything else.
