Fix numpy expand and make numpy backend jittable #317

Merged 2 commits into pyro-ppl:master on Feb 14, 2020

Conversation

fehiepsi
Member

The following fixes came up while I was profiling GaussianHMM in #315:

  • Support -1 in expand shapes for the numpy backend, using @fritzo's suggestion (see the sketch after this list).
  • Replace UnshapedArray with jax.core.Tracer. A device array x becomes a Tracer under jit/vmap/pmap; this tracer has an attribute x.aval, which is an UnshapedArray.
  • Make eager_cat_homogeneous backend-agnostic.
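
A minimal sketch of the first two bullets, with a hypothetical helper name _expand (illustrative only, not the PR's actual code):

import numpy as np
import jax.core

def _expand(x, shape):
    # Torch-style expand for numpy: -1 means "keep the existing size"
    # for a dimension that x already has.
    offset = len(shape) - np.ndim(x)
    resolved = tuple(x.shape[i - offset] if size == -1 else size
                     for i, size in enumerate(shape))
    return np.broadcast_to(x, resolved)

assert _expand(np.ones((3, 1)), (2, -1, 4)).shape == (2, 3, 4)

def is_traced(x):
    # Under jit/vmap/pmap a device array becomes a jax.core.Tracer; its
    # abstract shape/dtype information lives on x.aval.
    return isinstance(x, jax.core.Tracer)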

Follow-up PR:

  • enhance the speed of Gaussian

@@ -326,7 +326,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs):
# Then g is an unbiased estimator of f in value and all derivatives.
# In the special case f = detach(f), we can simplify to
# g = delta(x=x0) |f|.
-if flat_logits.requires_grad:
+if hasattr(flat_logits, "requires_grad") and flat_logits.requires_grad:
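
To make the new guard concrete (an illustrative check, not code from the PR): only torch tensors expose .requires_grad, so numpy and jax arrays short-circuit to False here.

import numpy as np
import torch

def takes_if_branch(flat_logits):
    # Only torch tensors have .requires_grad; numpy (and jax) arrays do
    # not, so hasattr(...) is False and the whole condition is False.
    return hasattr(flat_logits, "requires_grad") and flat_logits.requires_grad

assert takes_if_branch(torch.zeros(3, requires_grad=True))
assert not takes_if_branch(torch.zeros(3))
assert not takes_if_branch(np.zeros(3))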
Member Author

I am not sure whether, for the numpy backend, we should exercise the if branch or the else branch here.

Member

  1. Does JAX have a .detach() method, as needed in the DiCE branch? Does JAX have a .requires_grad checker?
  2. Should we have three backends, "numpy", "jax", and "torch"? In that case, "numpy" should use the second branch and "jax" should use the first (i.e. DiCE) branch.

Member

> Should we have three backends, "numpy", "jax", and "torch"?

+1 to this, with "numpy" as the default backend; that way, dependencies heavier than numpy are optional and backend-specific.

Member

> Does JAX have a .detach() method, as needed in the DiCE branch? Does JAX have a .requires_grad checker?

The analog of detach is lax.stop_gradient, but I don't think there is a way to check for a requires_grad flag (and it probably wouldn't work under jit anyway). I suspect we may have to write this as a JAX custom primitive and specify its behavior under ad.
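
A minimal sketch of lax.stop_gradient acting as a detach analog:

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # The stop_gradient factor passes through unchanged in value but is
    # treated as a constant by autodiff, much like torch's .detach().
    return jnp.sum(x * lax.stop_gradient(x))

x = jnp.ones(3)
print(f(x))            # 3.0
print(jax.grad(f)(x))  # [1. 1. 1.]  (the detached factor contributes no gradient)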

Member Author

Thanks all! I understand the situation now. I'll address this in an upcoming PR, where I'll add a global BACKEND ("numpy"/"torch"/"jax"). As @neerajprad pointed out, we can use lax.stop_gradient in this branch.

About requires_grad, I don't think we need to worry about it. If JAX users don't want to take grads here, the overhead of the if branch is pretty small. Please let me know if I'm missing crucial points here (e.g. sb_inputs is different from batch_inputs...)
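
A rough sketch of the global-BACKEND idea (the _BACKEND flag and detach helper below are hypothetical, not funsor's actual API):

import numpy as np

_BACKEND = "numpy"  # would be set to "numpy", "torch", or "jax" globally

def detach(x):
    # Dispatch detach-like behavior on the global backend flag.
    if _BACKEND == "torch":
        return x.detach()
    if _BACKEND == "jax":
        from jax import lax
        return lax.stop_gradient(x)
    return x  # plain numpy arrays carry no gradient state to detach

assert detach(np.ones(3)).shape == (3,)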

Member

OK, so as I understand it, in this branch you will have:

if backend == "torch" and not flat_logits.requires_grad:
    # use the shortcut
else:  # numpy, jax, and torch with grad
    # use the DiCE version

and in a follow-up PR you can add "numpy"/"jax"/"torch" backends and optimize the numpy branch like this?

if backend == "numpy" or (backend == "torch" and not flat_logits.requires_grad):
    # use the shortcut
else:  # jax, and torch with grad
    # use the DiCE version

Member Author

Yes, but let's keep this PR as-is because the logic here (for backends other than torch) is not exercised in any test yet.

@fritzo fritzo merged commit 26afb5d into pyro-ppl:master Feb 14, 2020