Add initial support for torch.utils.checkpoint #1127

Merged
merged 6 commits into main from functional-autograd-checkpoint on Oct 18, 2024

Conversation

IvanYashchuk
Collaborator

A checkpointed function doesn't save any intermediates from forward to backward. Instead, all required values are recomputed during the backward pass. Because fewer intermediates are saved, peak memory usage usually decreases.

This PR introduces support for recognizing torch.utils.checkpoint.checkpoint calls and inserting a new bound symbol into the initial trace. In the forward-backward generation pass, this bound symbol is then converted into the augmented-forward and backward parts of the computation. This step requires the function argument to thunder.torch.checkpoint to be a Thunder function. Currently, no PyTorch->Thunder conversion is implemented, so this works only for simple functions that are recognized by both Thunder and PyTorch, for example when only tensor methods are used.

The PyTorch function needs to be converted to a Thunder function in Thunder's JIT. Previously we could simply use thunder.preprocess, which is no longer available. When I attempted to implement redispatching/reinterpretation of PyTorch functions using general_thunder_jit, I hit the following bug: #1126.

Example:

import thunder
import torch

def f(x):
    return torch.utils.checkpoint.checkpoint(lambda x: x.sin().cos().exp(), x)

jf = thunder.jit(f)
x = torch.randn(3, 4, device="cuda", requires_grad=True)
jf(x).backward(x)
print(thunder.last_traces(jf)[-1])
print(thunder.last_backward_traces(jf)[-1])

Forward execution trace:

def augmented_forward_fn(x):
  # x: "cuda:0 f32[3, 4]"
  [t2] = nvFusion0(x)
    # t0 = prims.sin(x)  # t0: "cuda:0 f32[3, 4]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[3, 4]"
    # t2 = prims.exp(t1)  # t2: "cuda:0 f32[3, 4]"
  return {'output': t2, 'flat_args': [x], 'flat_output': (t2,)}, ((x,), ())

Backward execution trace:

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t3, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  x, = C0
  clear_mutable_collection(C0)
  del C0
  [t12] = nvFusion0(x, t3)
    # t4 = prims.sin(x)  # t4: "cuda:0 f32[3, 4]"
    # t11 = prims.cos(x)  # t11: "cuda:0 f32[3, 4]"
    # t5 = prims.cos(t4)  # t5: "cuda:0 f32[3, 4]"
    # t8 = prims.sin(t4)  # t8: "cuda:0 f32[3, 4]"
    # t6 = prims.exp(t5)  # t6: "cuda:0 f32[3, 4]"
    # t7 = prims.mul(t3, t6)  # t7: "cuda:0 f32[3, 4]"
    # t9 = prims.neg(t8)  # t9: "cuda:0 f32[3, 4]"
    # t10 = prims.mul(t7, t9)  # t10: "cuda:0 f32[3, 4]"
    # t12 = prims.mul(t10, t11)  # t12: "cuda:0 f32[3, 4]"
  del x, t3
  return (t12,)

Collaborator

@mruberry mruberry left a comment

Cool! @t-vi, do you want to take a look?

t-vi
t-vi previously requested changes Sep 9, 2024
Collaborator

@t-vi t-vi left a comment

I'm not convinced of the design.
Why would we not just let the function be any function and have a state "currently checkpointing" that informs Thunder to add a tag to the proxies that are generated during the checkpointing instead?
We would need to clear that tag on the outputs, but that would be easier than reentrant jit and higher order functions.
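For illustration only, here is a minimal sketch of that tag-based idea; the names (CHECKPOINTING, TAG_RECOMPUTE, make_proxy) are hypothetical and are not existing Thunder APIs:

from contextlib import contextmanager

CHECKPOINTING = False  # the "currently checkpointing" interpreter state
TAG_RECOMPUTE = "recompute_in_backward"

@contextmanager
def checkpointing():
    # Entered while interpreting the body of a checkpointed function.
    global CHECKPOINTING
    prev, CHECKPOINTING = CHECKPOINTING, True
    try:
        yield
    finally:
        CHECKPOINTING = prev

def make_proxy(value, tags=()):
    # While the flag is set, every newly created proxy gets the recompute tag;
    # the tag would then be cleared on the outputs of the checkpointed region.
    tags = set(tags)
    if CHECKPOINTING:
        tags.add(TAG_RECOMPUTE)
    return {"value": value, "tags": tags}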

@IvanYashchuk
Collaborator Author

Why would we not just let the function be any function and have a state "currently checkpointing" that informs Thunder to add a tag to the proxies that are generated during the checkpointing instead? We would need to clear that tag on the outputs, but that would be easier than reentrant jit.

Do you have ideas about how the "currently checkpointing" approach would generalize to supporting, for example, torch.cond? Please continue in the issue #1134.

@t-vi
Collaborator

t-vi commented Sep 10, 2024

I don't have immediate ideas, but I don't see that we should be having higher order functions right now.
If anything it's the wrong sequencing.

@syed-ahmed
Collaborator

@IvanYashchuk You might wanna check out selective activation checkpointing available in PyTorch nightlies: https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts to specify which activations to save for backward.

@IvanYashchuk
Collaborator Author

@IvanYashchuk You might wanna check out selective activation checkpointing available in PyTorch nightlies: https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts to specify which activations to save for backward.

Awesome, thanks for the link, Syed! Not a fan of ATen ops leaking into the PyTorch Python interface with torch.matmul becoming torch.ops.aten.mm.default, but I will check out how it could be recognized by Thunder.
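
A minimal usage sketch of the linked API on the PyTorch side, following the linked documentation and assuming a PyTorch build that ships create_selective_checkpoint_contexts; the function f and the tensors below are made up for illustration:

import functools

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Ops listed here are saved for backward; everything else inside the
# checkpointed region is recomputed. Note the ATen-level naming
# (torch.ops.aten.mm.default) rather than torch.matmul.
ops_to_save = [torch.ops.aten.mm.default]
context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)

def f(x, w):
    return torch.matmul(torch.sin(x), w).cos().sum()

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
out = checkpoint(f, x, w, use_reentrant=False, context_fn=context_fn)
out.backward()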

@IvanYashchuk IvanYashchuk merged commit 3f3d46a into main Oct 18, 2024
37 checks passed
@IvanYashchuk IvanYashchuk deleted the functional-autograd-checkpoint branch October 18, 2024 10:30