

Move from interpret_trace->transform->construct_trace to working with traces directly #1138

Open
t-vi opened this issue Sep 11, 2024 · 4 comments

@t-vi
Collaborator

t-vi commented Sep 11, 2024

The autocast and gradient-related transforms use

interpret_trace->transform->construct_trace

This drops information, e.g. tags on the proxies.
So this issue is about making them work with the traces directly.

My idea here is NOT to stop calling symbols, but to construct the new trace closer to where we call them.
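To make the information loss concrete, here is a toy model. All names (`Proxy`, `BoundSym`, `call_op`, and the two loop functions) are hypothetical and are not the project's actual classes. When the new trace is built by re-calling symbols, each call mints a fresh output proxy that knows nothing about the proxy it replaces; if the loop that creates the new bsym still has the old bsym in hand, the tags can be copied over.

```python
import itertools
from dataclasses import dataclass, field

_counter = itertools.count()


@dataclass
class Proxy:
    """Hypothetical stand-in for a trace proxy; only name and tags matter here."""
    name: str
    tags: set = field(default_factory=set)


@dataclass
class BoundSym:
    op: str
    args: list
    output: Proxy


def call_op(op, args, new_trace):
    # Calling a symbol records a bsym and returns a *fresh* output proxy,
    # which knows nothing about the proxy it replaces in the old trace.
    out = Proxy(f"t{next(_counter)}")
    new_trace.append(BoundSym(op, args, out))
    return out


def interpret_transform_construct(trace):
    """interpret_trace -> transform -> construct_trace, in miniature."""
    new_trace, env = [], {}
    for bsym in trace:
        args = [env.get(a.name, a) for a in bsym.args]
        env[bsym.output.name] = call_op(bsym.op, args, new_trace)
    return new_trace


def transform_in_place(trace):
    """Same loop, but the old bsym is in scope when the new one is created."""
    new_trace, env = [], {}
    for bsym in trace:
        args = [env.get(a.name, a) for a in bsym.args]
        out = call_op(bsym.op, args, new_trace)
        out.tags |= bsym.output.tags  # carry the proxy tags over
        env[bsym.output.name] = out
    return new_trace
```

The only difference between the two functions is that the second still sees the old bsym when the new output proxy is created, which is the point of constructing the new trace closer to where symbols are called.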

@IvanYashchuk
Collaborator

We already have tags on Symbols:

tags: None | list[OpTags] = None

I see that "tags on proxies" were added in #1048, but there's no documentation about their intended usage:

tags: set | None = None,

Can we unify tags on symbols and tags on proxies?

> interpret_trace->transform->construct_trace
> this drops information, e.g. tags on the proxies.

What does it mean to preserve tags on proxies when, in general, the transformation is allowed to produce completely new outputs, possibly a different number of them?

If we only had tags on symbols, I would suggest modifying the part of the trace interpreter that calls the symbol, wrapping that call in a context manager that holds the current tags to be set on newly created symbols:

prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
if prim_func is None:
    continue
result = prim_func(*args, **kwargs)
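A minimal sketch of that suggestion, assuming a `contextvars`-based design. `current_tags`, `tags_scope`, and the toy `Symbol`/`BoundSymbol` classes are hypothetical and are not the real interpreter or Symbol API. The interpreter loop wraps each mapped call in the old bsym's tags, and binding reads whatever tags are active.

```python
import contextvars
from contextlib import contextmanager
from dataclasses import dataclass

# Hypothetical context variable holding the tags that newly bound symbols receive.
current_tags = contextvars.ContextVar("current_tags", default=frozenset())


@contextmanager
def tags_scope(tags):
    """Extend the active tag set for the duration of the `with` block."""
    token = current_tags.set(current_tags.get() | frozenset(tags))
    try:
        yield
    finally:
        current_tags.reset(token)


@dataclass(frozen=True)
class BoundSymbol:
    name: str
    tags: frozenset = frozenset()


@dataclass
class Symbol:
    name: str
    trace: list  # toy stand-in for the trace being constructed

    def __call__(self, *args, **kwargs):
        # Binding reads the ambient tags instead of receiving them explicitly.
        bsym = BoundSymbol(self.name, current_tags.get())
        self.trace.append(bsym)
        return bsym


def interpret(old_trace, symbol_mapper, *args, **kwargs):
    """Toy interpreter loop: each mapped call runs under the old bsym's tags."""
    for bsym in old_trace:
        prim_func = symbol_mapper(bsym)
        if prim_func is None:
            continue
        with tags_scope(bsym.tags):
            prim_func(*args, **kwargs)
```

If a mapper decomposes one old bsym into several new calls, every resulting bsym carries the old bsym's tags. Because the tags live on bsyms rather than on outputs, the number of outputs the transformation produces does not matter in this sketch.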

Are tags on proxies queryable like any other metadata (shape, dtype, device) at JIT-tracing time? If so, should the grad rule functions then be modified to copy the special tags from inputs to outputs?
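If proxy tags are queryable at tracing time, one option for the second question is a wrapper that copies a chosen subset of tags from a rule's inputs to all of its outputs, however many there are. This is a hypothetical sketch: `Proxy`, `PROPAGATED_TAGS`, `propagate_tags`, `sin_grad_rule`, and the `SPECIAL` tag are illustrative names, and deciding which tags should propagate is exactly the open design question.

```python
import functools
from dataclasses import dataclass, field


@dataclass
class Proxy:
    """Hypothetical proxy whose tags are queryable like shape or dtype."""
    name: str
    tags: set = field(default_factory=set)


# Hypothetical: the only tags allowed to flow from inputs to outputs.
PROPAGATED_TAGS = frozenset({"SPECIAL"})


def propagate_tags(rule):
    """Wrap a grad rule so that all of its outputs inherit propagated input tags."""
    @functools.wraps(rule)
    def wrapper(*args):
        inherited = set()
        for a in args:
            if isinstance(a, Proxy):
                inherited |= a.tags & PROPAGATED_TAGS
        outs = rule(*args)
        for o in outs if isinstance(outs, tuple) else (outs,):
            if isinstance(o, Proxy):
                o.tags |= inherited
        return outs
    return wrapper


@propagate_tags
def sin_grad_rule(x):
    # Toy rule: returns a fresh primal output and a residual for backward.
    return Proxy(f"sin_{x.name}"), Proxy(f"cos_{x.name}")
```

Tags outside the whitelist (for example, one that only makes sense on a specific input buffer) are not copied, which is why the sketch filters instead of copying everything.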

@t-vi
Collaborator Author

t-vi commented Sep 25, 2024

To my mind, we conceptually need tags on proxies because properties like "STATIC_MEMORY_LOCATION" or "DONT_SAVE_FOR_BACKWARD" are properties of the proxies.
In the future we might tag proxies based on their consumers, so putting the tags on the producing bsyms might not suffice.

The drawback of the approach proposed here is that it fundamentally only works at the top level, so any additional information on subsymbols is lost.
A possible alternative would be to set up context variables holding a (hierarchical) "preview" of the old trace, so that symbols can inherit tags wherever they follow the preview. I'd be totally in favor of exploring that route.
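One possible reading of the "preview" idea, sketched under stated assumptions: the old trace is exposed through a context variable one hierarchy level at a time, and a new call "follows the preview" when its symbol name matches the next old bsym at that level. Matching by name and position is a simplification made for this sketch, and `OldBsym`, `NewBsym`, `Cursor`, `call_symbol`, and `retrace` are hypothetical names.

```python
import contextvars
from dataclasses import dataclass, field


@dataclass
class OldBsym:
    """A bound symbol from the old trace (hypothetical, simplified)."""
    name: str
    tags: frozenset = frozenset()
    subsymbols: list = field(default_factory=list)


@dataclass
class NewBsym:
    name: str
    tags: frozenset
    subsymbols: list = field(default_factory=list)


class Cursor:
    """Position within one level of the old trace's bsym hierarchy."""

    def __init__(self, bsyms):
        self.bsyms, self.pos = bsyms, 0

    def match(self, name):
        # If the next old bsym has the same symbol name, the new call
        # "follows the preview": consume that old bsym and return it.
        if self.pos < len(self.bsyms) and self.bsyms[self.pos].name == name:
            self.pos += 1
            return self.bsyms[self.pos - 1]
        return None


preview = contextvars.ContextVar("preview", default=Cursor([]))
recording = contextvars.ContextVar("recording")


def _run(cursor, target, fn):
    # Evaluate fn with the given preview level and recording target.
    p, r = preview.set(cursor), recording.set(target)
    try:
        fn()
    finally:
        recording.reset(r)
        preview.reset(p)


def call_symbol(name, decomposition=None):
    """Record a new bsym, inheriting tags if it matches the preview."""
    old = preview.get().match(name)
    bsym = NewBsym(name, old.tags if old else frozenset())
    recording.get().append(bsym)
    if decomposition is not None:
        # Subsymbols are matched against the old bsym's subsymbols, if any.
        _run(Cursor(old.subsymbols if old else []), bsym.subsymbols, decomposition)
    return bsym


def retrace(old_trace, program):
    new_trace = []
    _run(Cursor(old_trace), new_trace, program)
    return new_trace
```

Because matching happens at every level, subsymbols can inherit tags too, which addresses the top-level-only limitation; a call that does not match the preview simply inherits nothing.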

That said, the exercise in #1164 also left me wondering about the apparent tower of complexity in the autograd bits (I am not saying it is unnecessary complexity, just that it is quite nested). It certainly seems to be a source of the inconsistencies we have been seeing in transform for execution etc.

@t-vi
Collaborator Author

t-vi commented Sep 27, 2024

Thinking about @IvanYashchuk's comment more, I'm warming to the idea that maybe putting a rematerialization tag on the bsym and having its subsymbols inherit it would be better than putting it on the proxies. WDYT?
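If the rematerialization decision lives on the bsym, inheriting it is a small recursive pass over the subsymbol tree. A sketch, assuming a frozen toy `BoundSymbol` and hypothetical `REMAT` and `LOCAL` tags; `inherit_tags` is an illustrative name. Only tags in the `heritable` set are passed down, so bsym-local tags stay where they are.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class BoundSymbol:
    """Toy bound symbol (hypothetical): name, tags, and nested subsymbols."""
    name: str
    tags: frozenset = frozenset()
    subsymbols: tuple = ()


def inherit_tags(bsym, inherited=frozenset(), heritable=frozenset({"REMAT"})):
    """Return a copy of `bsym` whose subsymbols inherit heritable parent tags."""
    tags = bsym.tags | inherited
    passed_down = tags & heritable  # only heritable tags flow to children
    children = tuple(inherit_tags(s, passed_down, heritable) for s in bsym.subsymbols)
    return replace(bsym, tags=tags, subsymbols=children)
```

Returning a new tree rather than mutating in place keeps the original trace intact, which matters if earlier traces are kept around for inspection.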

@IvanYashchuk
Collaborator

> Thinking about @IvanYashchuk's comment more, I'm warming to the idea that maybe putting a rematerialization tag on the bsym and inheriting it to the subsymbols would be better than putting it on the proxies. WDYT?

This can work without much modification to the current code if BoundSymbol creation inside Symbol.__call__ can somehow query which tags should currently be added to the bsym:

bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)

Then, inside the trace interpreter, the current tags are set before Symbol.__call__ is invoked. This could be done, for example, by wrapping the function returned by vjp_symbol_mapper (unrelated, but it should be renamed) so that it sets the needed tags taken from the input symbol argument:
def vjp_symbol_mapper(symbol: prims.Symbol, *args, **kwargs):
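The wrapping can be sketched without touching the interpreter loop itself: the mapper's returned function is wrapped so that it runs with the input bsym's tags active, and binding reads them from a context variable. `_active_tags`, `bind`, `tagging_mapper`, and the `REMAT` tag are hypothetical; in the real code the lookup would presumably happen where `Symbol.__call__` calls `self.bind(...)`.

```python
import contextvars
import functools
from dataclasses import dataclass

# Hypothetical ambient tags consulted whenever a BoundSymbol is created.
_active_tags = contextvars.ContextVar("active_tags", default=frozenset())


@dataclass(frozen=True)
class BoundSymbol:
    name: str
    tags: frozenset = frozenset()


def bind(name):
    """Toy stand-in for Symbol.bind: picks up the currently active tags."""
    return BoundSymbol(name, _active_tags.get())


def tagging_mapper(mapper):
    """Wrap a symbol mapper so each returned function runs under the bsym's tags."""
    @functools.wraps(mapper)
    def wrapped(bsym, *args, **kwargs):
        fn = mapper(bsym, *args, **kwargs)
        if fn is None:
            return None  # preserve the "skip this bsym" signal

        @functools.wraps(fn)
        def call_with_tags(*a, **kw):
            token = _active_tags.set(_active_tags.get() | bsym.tags)
            try:
                return fn(*a, **kw)
            finally:
                _active_tags.reset(token)
        return call_with_tags
    return wrapped
```

Since the wrapper only decorates the mapper, the existing `if prim_func is None: continue` check in the interpreter keeps working unchanged.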

Anyway, at this stage it would be valuable to write down a design document and review it with a larger group.


3 participants