Memory leak when raising an exception in jitted fn and catching it outside #1171

Closed
lantiga opened this issue Sep 18, 2024 · 0 comments · Fixed by #1193
Labels
dynamo+thunder for things that could be applicable to the dynamo+thunder frontend

lantiga commented Sep 18, 2024

From the analysis of #1170, there's one case where we still leak memory related to exceptions.

When the exception is not handled within the interpreter but is instead caught outside the jitted function, we leak memory.

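```python
import weakref

import torch
import thunder


class Identity(torch.nn.Module):
    def forward(self, x):
        raise RuntimeError("Error")
        return x  # unreachable; kept so the module still reads as an identity


def foo():
    model = thunder.jit(Identity())
    x = torch.randn(16, 16)

    try:
        model(x)  # the exception propagates out of the jitted function...
    except RuntimeError:
        pass  # ...and is caught here, outside the interpreter

    return weakref.ref(x)


weak_x = foo()

assert weak_x() is None  # this fails: x leaks
```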

The leak appears to be caused mainly by the exception keeping a reference to the function in its traceback:

[Image: cycle_9 (reference cycle graph)]
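The traceback mechanism can be reproduced in plain CPython without thunder. This is a minimal sketch under that assumption, not thunder's code: `Payload`, `raise_with`, `store_exception` and `store_exception_without_traceback` are hypothetical names. Keeping a caught exception alive past its `except` clause creates the cycle exception → traceback → frame → exception, which pins the locals of every frame in the traceback until the cyclic garbage collector runs; clearing `__traceback__` breaks the cycle.

```python
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a large object such as the tensor `x` above."""


def raise_with(obj):
    # `obj` is a local of this frame, so the traceback's frame keeps it alive.
    raise RuntimeError("Error")


def store_exception():
    payload = Payload()
    ref = weakref.ref(payload)
    try:
        raise_with(payload)
    except RuntimeError as e:
        # Keeping the exception past the except clause creates the cycle
        # exception -> traceback -> this frame -> `stored` -> exception.
        stored = e
    return ref


def store_exception_without_traceback():
    payload = Payload()
    ref = weakref.ref(payload)
    try:
        raise_with(payload)
    except RuntimeError as e:
        stored = e
    # Dropping the traceback releases the frames (and their locals) right away.
    stored.__traceback__ = None
    return ref


gc.disable()  # keep the automatic collector from hiding the cycle
try:
    leaked = store_exception()
    alive_before_collect = leaked() is not None
    gc.collect()  # the explicit collection still runs while gc is disabled
    alive_after_collect = leaked() is not None
    fixed = store_exception_without_traceback()
    alive_without_traceback = fixed() is not None
finally:
    gc.enable()

print(alive_before_collect, alive_after_collect, alive_without_traceback)
```

On CPython this should print `True False False`: the payload survives the function's return and is only freed by `gc.collect()`, while the variant that drops the traceback frees it as soon as the function returns.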

Full repro:

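```python
import os
import weakref

import torch
import thunder
import refcycle  # pip install refcycle


class Identity(torch.nn.Module):
    def forward(self, x):
        raise RuntimeError("FOOBAR")
        return x


def main():
    with torch.device("cpu"):
        model = thunder.jit(Identity())
        x = torch.randn(16, 16)

    try:
        model(x)
    except RuntimeError:
        pass  # handle the exception outside the jitted function
    return weakref.ref(x)


weak_x = main()

if weak_x() is not None:
    # Snapshot every object currently tracked by the garbage collector
    snap = refcycle.snapshot()

    # Confirm the leaked tensor is still in memory
    for t in snap.find_by(lambda t: isinstance(t, torch.Tensor)):
        print(t.shape)  # Prints - torch.Size([16, 16])

    # Find the first ancestor of x that lies on a reference cycle and save it as an image
    os.makedirs("cycles", exist_ok=True)
    for idx, anc in enumerate(snap.ancestors(weak_x())):
        print(idx)
        try:
            cycle = snap.shortest_cycle(anc)
            cycle.export_image(f"cycles/cycle_{idx}.png")  # requires Graphviz
            print(cycle)
            break
        except Exception:
            pass  # no cycle through this ancestor (or export failed); try the next one
else:
    print("No leaks!")
```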
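A stdlib-only check can complement the refcycle analysis by telling whether an object is pinned by an unreachable reference cycle (freed once `gc.collect()` runs) or by a reachable strong reference. This is a sketch under that assumption; `classify`, `Node`, `make_plain` and `make_cyclic` are hypothetical names, not part of thunder or refcycle.

```python
import gc
import weakref


def classify(ref: weakref.ref) -> str:
    """Hypothetical helper: explain why a weakly referenced object is (not) alive."""
    if ref() is None:
        return "freed"  # reference counting alone released it
    gc.collect()  # run CPython's cyclic garbage collector
    if ref() is None:
        return "cycle"  # only an unreachable reference cycle was keeping it alive
    return "alive"  # something reachable still holds a strong reference


class Node:
    def __init__(self):
        self.other = None


def make_plain():
    node = Node()
    return weakref.ref(node)  # `node` is released when the function returns


def make_cyclic():
    node = Node()
    node.other = node  # self-reference: the refcount never reaches zero on its own
    return weakref.ref(node)


held = Node()  # a live, reachable reference

gc.disable()  # only the explicit gc.collect() inside classify() runs
try:
    results = (classify(make_plain()), classify(make_cyclic()), classify(weakref.ref(held)))
finally:
    gc.enable()

print(results)  # ('freed', 'cycle', 'alive')
```

In the repro, calling `classify(weak_x)` instead of the bare `weak_x() is not None` check would distinguish the two cases; since it runs a collection, it should come after (or replace) the refcycle snapshot rather than before it.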

tfogal added the dynamo+thunder label on Sep 20, 2024