Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement robust U-turn check #3605

Merged
merged 18 commits into from
Oct 17, 2019
Merged
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,18 @@ def extend(self, direction):
if direction > 0:
tree, diverging, turning = self._build_subtree(
self.right, self.depth, floatX(np.asarray(self.step_size)))
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
leftmost_p_sum = self.p_sum
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
tree, diverging, turning = self._build_subtree(
self.left, self.depth, floatX(np.asarray(-self.step_size)))
leftmost_begin, leftmost_end = tree.right, tree.left
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
rightmost_p_sum = self.p_sum
self.left = tree.right

self.depth += 1
Expand All @@ -271,9 +279,16 @@ def extend(self, direction):
self.log_size = np.logaddexp(self.log_size, tree.log_size)
self.p_sum[:] += tree.p_sum

left, right = self.left, self.right
p_sum = self.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
# Additional turning check only when tree depth > 0 to avoid redundant work
if self.depth > 0:
left, right = self.left, self.right
p_sum = self.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
p_sum1 = leftmost_p_sum + rightmost_begin.p
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
p_sum2 = leftmost_end.p + rightmost_p_sum
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
turning = (turning | turning1 | turning2)

return diverging, turning

Expand Down Expand Up @@ -324,6 +339,13 @@ def _build_subtree(self, left, depth, epsilon):
if not (diverging or turning):
p_sum = tree1.p_sum + tree2.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
# Additional U turn check only when depth > 1 to avoid redundant work.
if depth - 1 > 0:
p_sum1 = tree1.p_sum + tree2.left.p
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
p_sum2 = tree1.right.p + tree2.p_sum
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
turning = (turning | turning1 | turning2)

log_size = np.logaddexp(tree1.log_size, tree2.log_size)
if logbern(tree2.log_size - log_size):
Expand Down