Move to SciPy optimizers, made differentiable through jax.lax.custom_root
#4
Comments
You can easily achieve this in fax. We originally implemented the two-phase approach before `jax.lax.custom_root` was available. Otherwise, the two-phase approach is essentially a more specific instance of `custom_root`: it leverages the attractiveness assumption to solve for the derivatives by simply iterating (which is guaranteed to converge in this case), while `custom_root` leaves the choice of how to solve the linearised system up to you via `tangent_solve`. One thing to note is that, last we checked, it's not obvious how to "extract" intermediate information from the backwards pass when using `custom_root`. I hope this helps clarify the differences and similarities between the two approaches.
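As a rough sketch of the two-phase idea (the toy map `T` and the names `fixed_point`/`implicit_grad` below are made up for illustration, not fax's API): phase one finds the fixed point by iterating, and phase two obtains the derivative by iterating the linearised update, which converges precisely because the fixed point is attractive.

```python
import jax
import jax.numpy as jnp

# Hypothetical contraction T(x, theta); its fixed point x* solves x* = T(x*, theta).
def T(x, theta):
    return 0.5 * x + jnp.sin(theta)  # |dT/dx| = 0.5 < 1, so the fixed point is attractive

def fixed_point(theta, x0=0.0, n_iter=100):
    x = x0
    for _ in range(n_iter):  # phase one: forward iteration to the fixed point
        x = T(x, theta)
    return x

def implicit_grad(theta):
    # Phase two: dx*/dtheta = dT/dtheta + dT/dx * dx*/dtheta at x*.
    # Solve it by iterating the linear recursion (a geometric series that
    # converges because |dT/dx| < 1).
    x_star = fixed_point(theta)
    dT_dx = jax.grad(T, argnums=0)(x_star, theta)
    dT_dtheta = jax.grad(T, argnums=1)(x_star, theta)
    g = 0.0
    for _ in range(100):
        g = dT_dtheta + dT_dx * g
    return g

print(implicit_grad(1.3))          # ~2*cos(1.3) ~ 0.535
print(jax.grad(fixed_point)(1.3))  # same value, obtained by unrolling the forward loop
```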
Fyi, we're discussing re-implementing fax's solvers on top of `jax.lax.custom_root`.
Thanks for letting us know @gehring! From what you’re saying, it sounds like one could outright use SciPy optimisers/root-finders without writing the corresponding jax code, which would cover the use case in this issue (and another one that I haven’t written on GH yet) — my thoughts are only positive! @lukasheinrich, do you think this could end up baked straight into pyhf’s optimise module? (Also thanks for your previous comment @gehring, it clarified some things for me!)
I think you are referring only to the solver and, in that case, yes. Note that you can already do that with fax's two-phase solver. However, just to make sure things are extra clear, I'll mention that you will still need to provide the function whose root (or fixed point) you are solving for, e.g. the gradient of the objective, as JAX-traceable code, since that is what the implicit differentiation is carried out through; only the solver itself can be an arbitrary black box. The fixed-point case (e.g. using the two-phase solver) works the same way.
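To make the "you still provide `f` in JAX" point concrete, here is a minimal `jax.lax.custom_root` sketch (the `bisect`/`my_sqrt` names and the toy numbers are made up for the example): the forward solver can be treated as a black box, but the root condition `f` and the `tangent_solve` are written in JAX so the implicit derivative can be formed.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy example: sqrt(x) as the root of f(y) = y**2 - x, for 0 <= x <= 4.
def bisect(f, lo, hi, n_iter=40):
    # Plain bisection; only used for the forward solve.
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        go_right = f(mid) < 0
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
    return 0.5 * (lo + hi)

def my_sqrt(x):
    f = lambda y: y ** 2 - x                    # root condition, written in JAX
    solve = lambda f, y0: bisect(f, 0.0, 2.0)   # the forward solver is a black box
    tangent_solve = lambda g, y: y / g(1.0)     # solves the (scalar) linearised system
    return lax.custom_root(f, 1.0, solve, tangent_solve)

print(my_sqrt(2.0))            # ~1.4142
print(jax.grad(my_sqrt)(2.0))  # ~0.3536, i.e. 1 / (2 * sqrt(2))
```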
The current implementation of the maximum likelihood fits uses gradient descent to converge to the optimal parameter values. In principle, switching to SciPy optimizers would be preferable, both for comparison with the optimization implementation in pyhf and for more robust minimization.
To do this, one needs to differentiate through the optimizer using implicit differentiation. It's probably possible to do this using fax like we do now, but this issue on the jax repo discusses the possibility of wrapping SciPy optimizers using `jax.lax.custom_root`, which would remove a dependency and (probably) make for simpler code.
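As a rough sketch of what this could look like (the names `fit` and `run_scipy` are hypothetical, the least-squares `loss` is a toy stand-in for the likelihood, and the NumPy objective inside the callback simply mirrors the JAX loss; this is not pyhf's or fax's API): the optimality condition grad(loss) = 0 is written in JAX, the forward solve calls `scipy.optimize.minimize` behind `jax.pure_callback`, and `jax.lax.custom_root` supplies the implicit derivative. It is meant to run eagerly, i.e. under `jax.grad` without `jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import optimize

# Hypothetical toy objective; in the real use case this would be the (negative log-)likelihood.
def loss(theta, data):
    return jnp.sum((data - theta) ** 2) + 0.1 * jnp.sum(theta ** 2)

def fit(data, theta0):
    # Root condition: the gradient of the loss vanishes at the optimum.
    def optimality(theta):
        return jax.grad(loss)(theta, data)

    # Forward solve: SciPy runs on concrete NumPy values behind a callback,
    # so JAX only needs to know the output shape and dtype.
    def scipy_solve(f, theta_init):
        del f  # we minimise the loss directly rather than root-finding its gradient

        def run_scipy(theta_init_np, data_np):
            # NumPy mirror of `loss` above (kept separate so the host callback
            # does not call back into JAX).
            objective = lambda t: float(
                np.sum((data_np - t) ** 2) + 0.1 * np.sum(t ** 2))
            res = optimize.minimize(objective, theta_init_np, method="BFGS")
            return res.x.astype(theta_init_np.dtype)

        out_shape = jax.ShapeDtypeStruct(theta_init.shape, theta_init.dtype)
        return jax.pure_callback(run_scipy, out_shape, theta_init, data)

    # Solve the linearised optimality condition (i.e. apply the inverse Hessian).
    def tangent_solve(g, y):
        return jnp.linalg.solve(jax.jacobian(g)(y), y)

    return jax.lax.custom_root(optimality, theta0, scipy_solve, tangent_solve)

data = jnp.array([1.0, 2.0, 3.0])
theta0 = jnp.zeros(3)

theta_hat = fit(data, theta0)                              # ~ data / 1.1
grads = jax.grad(lambda d: jnp.sum(fit(d, theta0)))(data)  # ~ [0.909 0.909 0.909]
print(theta_hat, grads)
```

The `tangent_solve` here just materialises the Hessian of the loss at the optimum and solves against it directly; for larger models one would presumably swap in an iterative linear solver.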