Fix numpy expand and make numpy backend jittable #317
Conversation
```diff
@@ -326,7 +326,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs):
         # Then g is an unbiased estimator of f in value and all derivatives.
         # In the special case f = detach(f), we can simplify to
         #   g = delta(x=x0) |f|.
-        if flat_logits.requires_grad:
+        if hasattr(flat_logits, "requires_grad") and flat_logits.requires_grad:
```
I am not sure whether, for the numpy backend, we should exercise the `if` or the `else` branch here.
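For illustration (an editor's sketch, not code from the PR): a `numpy.ndarray` has no `requires_grad` attribute, so with the `hasattr` guard the numpy backend always falls through to the `else` branch:

```python
# Illustrative sketch: numpy arrays carry no requires_grad flag, so the
# guarded check is False and the numpy backend takes the else branch.
import numpy as np

flat_logits = np.zeros(3)
needs_dice = hasattr(flat_logits, "requires_grad") and flat_logits.requires_grad
assert not needs_dice  # an unguarded access would raise AttributeError
```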
- Does JAX have a `.detach()` method, as needed in the DiCE branch? Does JAX have a `.requires_grad` checker?
- Should we have three backends, "numpy", "jax", and "torch"? In that case, "numpy" should use the second branch and "jax" should use the first (i.e. DiCE) branch.
> Should we have three backends, "numpy", "jax", and "torch"?

+1 to this, with "numpy" as the default backend; that way dependencies heavier than numpy are optional and backend-specific.
> Does JAX have a `.detach()` method, as needed in the DiCE branch? Does JAX have a `.requires_grad` checker?

The analog of `detach` is `lax.stop_gradient`, but I don't think there is a way to check a `requires_grad` flag (and it probably wouldn't work under jit anyway). I suspect that we may have to write this as a JAX custom primitive and specify its behavior under `ad`.
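For reference, a minimal sketch of the analogy; `dice_factor` and its arguments are illustrative, not the PR's code. `jax.lax.stop_gradient` plays the role of `torch.Tensor.detach()` in a DiCE-style factor:

```python
import jax
import jax.numpy as jnp
from jax import lax

def dice_factor(logits, x0):
    # Log-probability of the sampled index x0.
    logprob = logits[x0] - jax.scipy.special.logsumexp(logits)
    # exp(logprob - stop_gradient(logprob)) is 1.0 in value but carries
    # d(logprob)/d(logits) in its derivatives, as in the DiCE estimator.
    return jnp.exp(logprob - lax.stop_gradient(logprob))

logits = jnp.array([0.1, 0.5, 0.4])
print(dice_factor(logits, 1))            # 1.0
print(jax.grad(dice_factor)(logits, 1))  # nonzero gradient w.r.t. logits
```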
Thanks all! I understand the situation now. I'll address this in an upcoming PR, where I'll add a global BACKEND ("numpy"/"torch"/"jax"). As @neerajprad pointed out, we can use `lax.stop_gradient` in this branch.

About `requires_grad`, I think we don't need to worry about it: if JAX users don't want to take `grad` here, the overhead of the `if` branch is pretty small. Please let me know if I'm missing crucial points here (e.g. `sb_inputs` is different from `batch_inputs`...).
OK, so as I understand it, in this branch you will

```python
if backend == "torch" and not flat_logits.requires_grad:
    ...  # use the shortcut
else:  # numpy, jax, and torch with grad
    ...  # use the DiCE version
```

and in a follow-up PR you can add "numpy"/"jax"/"torch" and optimize the numpy branch?

```python
if backend == "numpy" or (backend == "torch" and not flat_logits.requires_grad):
    ...  # use the shortcut
else:  # jax, and torch with grad
    ...  # use the DiCE version
```
Yes, but let's keep this PR as-is, because the logic here (for backends other than `torch`) is not exercised in any test yet.
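(Editor's note: a hypothetical minimal sketch of the global BACKEND switch discussed above; `get_backend`/`set_backend` are assumed names, not the follow-up PR's actual API.)

```python
# Hypothetical module-level backend switch; the real follow-up PR may differ.
_BACKEND = "numpy"  # default, so dependencies heavier than numpy stay optional

def get_backend():
    return _BACKEND

def set_backend(name):
    global _BACKEND
    if name not in ("numpy", "torch", "jax"):
        raise ValueError(f"unknown backend: {name}")
    _BACKEND = name
```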
The following fixes were found while profiling GaussianHMM in #315:

- Fixed `-1` in the expand shape for the numpy backend, using @fritzo's suggestion.
- `UnshapedArray` is replaced by `jax.core.Tracer`. A device array `x` will become a `Tracer` under jit/vmap/pmap; this tracer has an attribute `x.aval`, which is an `UnshapedArray` (see the sketch after this list).
- Made `eager_cat_homogeneous` backend-agnostic.

Follow-up PR:
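Regarding the `Tracer` point above, a minimal sketch (assuming a JAX version that exposes `jax.core.Tracer`) of why the check targets `Tracer` rather than `UnshapedArray`:

```python
import jax
import jax.numpy as jnp

def describe(x):
    # Under jit/vmap/pmap, x is a jax.core.Tracer; its shape/dtype
    # information lives in x.aval rather than on x itself.
    if isinstance(x, jax.core.Tracer):
        print("traced:", x.aval)  # e.g. ShapedArray(float32[3])
    else:
        print("concrete:", x)
    return x

describe(jnp.ones(3))           # concrete device array
jax.jit(describe)(jnp.ones(3))  # hits the Tracer branch while tracing
```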