I like PyTrees a lot, but in a PoC I was working on two years ago, I found that the jit time of using PyTrees of many small arrays was very expensive. I found that it could take up to hours to jit a …
This is a quickly thrown-together draft of a proof-of-concept idea that I think would be both cool to see and possible to prototype in a short timespan, mainly to see what's possible!
I don't imagine this kind of technical prototype as a replacement for `pyhf`, but I think there could be lessons learned by working on this functionality that could be upstreamed to the current codebase. Hope these ideas are useful :)

What?
PyTrees are JAX's native representation of structured data. Many built-in Python types are already PyTrees: `dict`, `list`, and `NamedTuple` are some examples. A custom class is not a PyTree by default, but thanks to libraries like equinox and simple-pytree, you can just inherit from the base class these libraries provide, and your class becomes a PyTree too (this could include, for example, `pyhf.Model`!).

Now, why are PyTrees useful? The main reason is that they're compatible with JAX function transforms: functions that take PyTrees as arguments can be used with `jax.jit`, `jax.vmap`, `jax.grad`, etc., which opens up a whole range of cool stuff. I'll try to summarize exactly how this could be useful for `pyhf` in the following sections. The main idea is basically this: if a library is built from the right JAX building blocks, you can freely use any of its features inside any of your own programs! (This is the design principle behind libraries like `diffrax`, as pointed out to me by the author.)

Why?
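As a minimal sketch of what "being a PyTree" buys in practice: equinox and simple-pytree do this through inheritance, but to stay dependency-free this sketch swaps in JAX's own `jax.tree_util.register_pytree_node_class`. `ToyChannel` and its `expected_data` method are hypothetical names for illustration, not part of `pyhf`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyChannel:
    """Hypothetical single-channel model: expected = mu * signal + background."""

    def __init__(self, signal, background):
        self.signal = signal
        self.background = background

    def expected_data(self, mu):
        return mu * self.signal + self.background

    def tree_flatten(self):
        # Array leaves that JAX should trace/differentiate, plus (empty) static metadata
        return (self.signal, self.background), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from (possibly transformed) leaves
        return cls(*children)


channel = ToyChannel(jnp.array([5.0, 3.0]), jnp.array([10.0, 6.0]))

# The object passes straight through JAX transforms...
expected = jax.jit(lambda ch, mu: ch.expected_data(mu))(channel, 1.0)

# ...and through tree utilities, which hand back a ToyChannel
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, channel)
leaves = jax.tree_util.tree_leaves(channel)
```

Because `tree_unflatten` rebuilds a `ToyChannel`, transforms and tree utilities return the same type they were given, which is what makes composing them painless.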
Vectorizable limits/hypotest/anything!
Vectorization/batching becomes super easy, e.g. for `hypotest`:
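As a hedged sketch of what this could look like, assume a toy single-bin counting experiment with no nuisance parameters. The hypothetical `toy_hypotest` below is not `pyhf.infer.hypotest`; it returns the asymptotic p-value of the one-sided test statistic q_mu, and `jax.vmap` batches it over many values of mu in a single call.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def nll(mu, n, s, b):
    # Poisson negative log-likelihood (up to a constant) for one bin
    lam = mu * s + b
    return lam - n * jnp.log(lam)


def toy_hypotest(mu, n, s, b):
    # Hypothetical stand-in for a hypotest: one-sided q_mu, asymptotic p-value
    muhat = (n - b) / s  # closed-form MLE for a single bin
    q = 2.0 * (nll(mu, n, s, b) - nll(muhat, n, s, b))
    q = jnp.where(muhat <= mu, jnp.maximum(q, 0.0), 0.0)  # q_mu = 0 if muhat > mu
    return 1.0 - norm.cdf(jnp.sqrt(q))


mus = jnp.linspace(0.0, 3.0, 31)
# Batch over mu only; observed count n, signal s and background b are shared
batched_hypotest = jax.vmap(toy_hypotest, in_axes=(0, None, None, None))
pvals = batched_hypotest(mus, 12.0, 5.0, 10.0)
```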
This should be particularly nice for limit scanning, which is currently a Python for-loop in the code (`pyhf/src/pyhf/infer/intervals/upper_limits.py`, lines 189 to 192 at commit 5161749).
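The loop-to-`vmap` swap for a limit scan can be sketched with a Gaussian stand-in for the per-point test. `p_value` is hypothetical and uses a plain one-sided p-value rather than CLs, chosen because its 95% upper limit is known analytically (x + 1.645 sigma). The whole scan becomes one vectorized, jit-compiled call, and `jnp.interp` reads off where the p-value crosses 0.05.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

ALPHA = 0.05  # 95% CL


def p_value(mu, x, sigma):
    # Hypothetical per-point test: Gaussian measurement x ~ N(mu, sigma),
    # one-sided (only mu above the observation is disfavored)
    z = jnp.maximum((mu - x) / sigma, 0.0)
    return 1.0 - norm.cdf(z)


@jax.jit
def upper_limit(x, sigma, mus):
    # Whole scan in one vectorized call instead of a Python for-loop
    pvals = jax.vmap(p_value, in_axes=(0, None, None))(mus, x, sigma)
    # p-values fall as mu grows; jnp.interp needs increasing xp, so flip both
    return jnp.interp(ALPHA, pvals[::-1], mus[::-1])


mus = jnp.linspace(0.0, 5.0, 501)
mu_up = upper_limit(1.0, 0.5, mus)  # analytic answer: 1.0 + 1.645 * 0.5
```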
Access to `jax.jit`
All workflows that involve PyTrees get access to just-in-time compilation, which, like all JAX transforms, can be applied either as a function transformation or as a decorator:
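A minimal sketch of both styles, assuming a toy two-bin Poisson model stored in a `NamedTuple` (which JAX already treats as a PyTree). `ToyModel` and this `logpdf` are hypothetical illustrations, not `pyhf`'s `Model.logpdf`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson


class ToyModel(NamedTuple):
    # NamedTuples are PyTrees out of the box; each field is a leaf
    signal: jnp.ndarray
    background: jnp.ndarray


def logpdf(model, mu, data):
    # Hypothetical binned Poisson log-likelihood, no nuisance parameters
    expected = mu * model.signal + model.background
    return jnp.sum(poisson.logpmf(data, expected))


# Style 1: jit as a function transformation
fast_logpdf = jax.jit(logpdf)


# Style 2: jit as a decorator
@jax.jit
def logpdf_jitted(model, mu, data):
    return logpdf(model, mu, data)


model = ToyModel(signal=jnp.array([5.0, 3.0]), background=jnp.array([10.0, 6.0]))
data = jnp.array([14.0, 8.0])
# The first call compiles; later calls with the same shapes reuse the compiled version
value = fast_logpdf(model, 1.0, data)
```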
As a very small anecdotal benchmark, I messaged @matthewfeickert a couple of years back claiming that `logpdf` could already be made 10x faster just by wrapping the method with `jax.jit` (I wonder if he remembers this :p).

True end-to-end differentiability
If models are PyTrees, all the differentiable model-building work becomes very easy, which would resolve a number of issues and PRs I've opened (e.g. #882, #1912). I found this out while experimenting with a fix for this jaxopt issue, which has essentially stopped `relaxed` from working on modern JAX versions. See this conversation for how easy things become with PyTrees!
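To sketch what differentiating over model construction can look like: below, a hypothetical upstream parameter `phi` (standing in for something like a cut threshold) smoothly scales the signal yields, and `jax.grad` flows through model building into the likelihood. Gradients with respect to the model itself come back as a `ToyModel` with the same structure. All names here are illustrative assumptions, not `pyhf` or `relaxed` API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson


class ToyModel(NamedTuple):
    signal: jnp.ndarray
    background: jnp.ndarray


def logpdf(model, mu, data):
    expected = mu * model.signal + model.background
    return jnp.sum(poisson.logpmf(data, expected))


def build_model(phi):
    # Hypothetical: an upstream analysis parameter phi smoothly scales the signal
    signal = jnp.array([5.0, 3.0]) * jax.nn.sigmoid(phi)
    background = jnp.array([10.0, 6.0])
    return ToyModel(signal, background)


def nll_from_phi(phi, mu, data):
    # Model construction happens inside the differentiated function
    return -logpdf(build_model(phi), mu, data)


data = jnp.array([14.0, 8.0])

# d(NLL)/d(phi), straight through model building
dnll_dphi = jax.grad(nll_from_phi)(0.0, 1.0, data)

# Gradient w.r.t. the whole model: returned as a ToyModel with the same structure
model_grads = jax.grad(logpdf)(build_model(0.0), 1.0, data)
```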
How?
This project feels much more realistic after looking through @pfackeldey's work on `dilax`, which uses all the main building blocks I've discussed here; I even managed to differentiate over model construction without him knowing about `neos`! My hunch for a realistic path forward is to follow the same design principles, but to build things up in a more `pyhf`-like way, so that the resulting API would look more or less the same as the current one. `dilax` is a great launchpad for this effort, if it is of interest!

Why haven't I done this yet?
I've toyed with this for a little while, but I don't have a good enough fundamental understanding of the best way to compose all the different parts of HistFactory (mostly ignorance around how systematics are handled and computed). So with a little advice and a couple of people looking at this (e.g. as a hackathon project for PyHEP.dev, which I think most people mentioned here will be attending), we could probably make this happen!
Oh, and I'm also technically not in the field right now, but I have 10% of my bandwidth that I'm allowed to spend on writing things like this ;)
My final, selfish reason is that this is the make-or-break ingredient for making things like `neos` possible for anyone who wants to try them in the future with a HiFa-based likelihood (I've already had a couple of people talk to me about that).

So yeah, hope this is useful, and maybe it turns into something cool!