How does jit compute cache keys and hits for generic PyTrees? #14996

Answered by jakevdp
KeAWang asked this question in Q&A
The rough mental model is this: if you pass a pytree y to the function, then internally the JIT machinery computes

children, treedef = tree_util.tree_flatten(y)

The children are treated as dynamic arguments, while treedef (and all the pytree auxiliary data it contains) is treated as a static argument, with its hash being rolled into the JIT cache key.
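To make that concrete, here is a small sketch that counts how often a jitted function is traced. The `Tagged` class, the `double` function, and the `trace_count` counter are hypothetical names made up for this illustration; only `jax.jit` and the `jax.tree_util` calls come from JAX. The leaf `value` plays the role of a child and `tag` the role of auxiliary data, so under the mental model above a new leaf value with the same shape and dtype should reuse the cached computation, while a new `tag` should change the treedef and trigger a retrace.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Tagged:
    """Hypothetical container: `value` is a child (leaf), `tag` is aux data."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        # Returns (children, aux_data); aux_data ends up inside the treedef.
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


trace_count = 0  # incremented by a Python side effect, which runs only while tracing


@jax.jit
def double(t):
    global trace_count
    trace_count += 1
    return t.value * 2


double(Tagged(jnp.zeros(3), "a"))
after_first = trace_count      # first call always traces

double(Tagged(jnp.ones(3), "a"))
after_new_leaf = trace_count   # same treedef, same shape/dtype: cache hit

double(Tagged(jnp.ones(3), "b"))
after_new_aux = trace_count    # different aux data, so a different treedef: retrace

# Treedefs compare equal only when structure and aux data both match.
same_treedef = (tree_util.tree_structure(Tagged(jnp.zeros(3), "a"))
                == tree_util.tree_structure(Tagged(jnp.ones(3), "a")))
diff_treedef = (tree_util.tree_structure(Tagged(jnp.zeros(3), "a"))
                == tree_util.tree_structure(Tagged(jnp.ones(3), "b")))

print(after_first, after_new_leaf, after_new_aux, same_treedef, diff_treedef)
```

One practical consequence of this sketch: anything returned as auxiliary data needs to be hashable and support equality comparison, and values that change often probably belong among the children rather than in the aux data, since every distinct aux value would cause a fresh trace.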

Answer selected by KeAWang