Move from interpret_trace->transform->construct_trace to working with traces directly #1138
Comments
We have tags for Symbols: lightning-thunder/thunder/core/symbol.py, line 133 in 314e748.
I see that "tags on proxies" were added in #1048, but there is no documentation of the intended usage: lightning-thunder/thunder/core/proxies.py, line 109 in 314e748.
Can we unify tags on symbols and tags on proxies?
What does it mean to preserve tags on proxies when, in general, a transformation is allowed to produce completely new and different outputs, and even a different number of outputs? If we only had tags on symbols, I would suggest modifying the part of the trace interpreter that calls the symbol, wrapping the call in a context manager so that the current tags are set on newly created symbols: lightning-thunder/thunder/core/trace_interpreter.py, lines 63 to 66 in 314e748.
Are tags on proxies queryable like any other metadata (shape, dtype, device) at jit-tracing time? If so, should the grad rule functions be modified to copy the special tags from inputs to outputs?
To my mind, we conceptually need tags on proxies because properties like "STATIC_MEMORY_LOCATION" or "DONT_SAVE_FOR_BACKWARD" are properties of the proxies. The drawback of the style proposed here is that it fundamentally only works at the top level, so any additional info on subsymbols is lost. That said, the exercise in #1164 also left me wondering about the apparent tower of complexity in the autograd bits (I am not saying it is unnecessary complexity, just that it is quite nested). It certainly seems to be a source of the inconsistencies we have been seeing in transform for execution etc.
Thinking about @IvanYashchuk's comment more, I'm warming to the idea that putting a rematerialization tag on the bsym and having the subsymbols inherit it might be better than putting it on the proxies. WDYT?
It can work without much modification to the current code if BoundSymbol creation (lightning-thunder/thunder/core/symbol.py, line 323 in 59467aa) picks up the current tags, and the trace interpreter sets the current tags before invoking Symbol.__call__. This could be done, for example, by wrapping the function returned from vjp_symbol_mapper (unrelated, but it should be renamed) so that it sets the needed tags taken from the input symbol argument: lightning-thunder/thunder/core/transforms.py, line 2463 in 59467aa.
Anyway, at this stage it would be valuable to write a design document and review it with a larger group.
The autocast and gradient-related transforms use the pattern interpret_trace->transform->construct_trace. This drops information, e.g. tags on the proxies.
So this issue is about making them work with the traces directly.
My idea here is NOT to stop calling symbols but to construct the new trace closer to where we do.
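As a purely illustrative toy model of the information loss described above (not thunder's interpreter; `ToyProxy`, `ToyBoundSymbol`, `replay`, and `edit_directly` are hypothetical names), the difference between re-running a trace and working on it directly can be sketched like this:

```python
from dataclasses import dataclass, field


@dataclass
class ToyProxy:
    name: str
    tags: set = field(default_factory=set)


@dataclass
class ToyBoundSymbol:
    op: str
    args: tuple
    output: ToyProxy


def replay(trace):
    """Re-run every op, creating fresh output proxies along the way."""
    env = {}
    new_trace = []
    for bsym in trace:
        new_args = tuple(env.get(a.name, a) for a in bsym.args)
        # A fresh proxy is created here, so tags on the old output are not carried over.
        new_out = ToyProxy(bsym.output.name)
        env[bsym.output.name] = new_out
        new_trace.append(ToyBoundSymbol(bsym.op, new_args, new_out))
    return new_trace


def edit_directly(trace):
    """Build the new trace from the existing bound symbols, reusing their proxies."""
    return [ToyBoundSymbol(b.op, b.args, b.output) for b in trace]


x = ToyProxy("x")
y = ToyProxy("y", {"DONT_SAVE_FOR_BACKWARD"})
trace = [ToyBoundSymbol("sin", (x,), y)]

replayed = replay(trace)
edited = edit_directly(trace)
```

In this model the replayed trace ends up with an untagged `y`, while the directly edited trace keeps the original proxy and its tag.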