
jax.lax.scan strange results compared to other methods #14348

Answered by soraros
jecampagne asked this question in Q&A

OK, I see the problem.
Your rescaleAbsWeights method is stateful: it modifies self on every iteration.
You can observe this by running the following:

class ClenshawCurtisQuad:
  ...
  def rescaleAbsWeights(...):
    jax.debug.print('called')
    ...
  ...

def body_fun(carry, k):
  ...
  jax.debug.print('{}', quad.absc.sum())
  ...

_, y = jax.lax.scan(body_fun, (rmax, beta, quad150), kvec)
y
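
To make the effect concrete, here is a minimal sketch of the same failure mode. ToyQuad and its rescale method are hypothetical stand-ins for ClenshawCurtisQuad and rescaleAbsWeights, not the code from the question. The assumption is that the object is a registered pytree and that the method overwrites an attribute on self. When such an object travels through the scan carry, the overwritten attribute is what gets carried into the next step, so the rescaling compounds.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyQuad:
    """Hypothetical stand-in for ClenshawCurtisQuad: quadrature nodes on [0, 1]."""

    def __init__(self, absc):
        self.absc = absc

    def rescale(self, b):
        # Stateful: overwrites self.absc instead of returning new nodes.
        self.absc = self.absc * b
        return self.absc

    def tree_flatten(self):
        return (self.absc,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def body_fun(carry, k):
    quad = carry
    nodes = quad.rescale(0.5)  # mutates the object that is returned as the carry
    return quad, nodes.sum() * k


quad = ToyQuad(jnp.array([0.0, 0.5, 1.0]))
_, y_scan = jax.lax.scan(body_fun, quad, jnp.ones(3))
print(y_scan)  # [0.75 0.375 0.1875]
```

If rescale were stateless, every entry would be 0.75. Here each step halves the nodes left behind by the previous step, which is why the scan output drifts away from what an independent per-k evaluation gives.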

The fix is easy: don't scan over quad. For example:

from functools import partial

def body_fun(carry, k):
  # rmax, beta and quad150 are captured from the enclosing scope
  y = quadIntegral(partial(integrant, k=k, beta=beta), 0., rmax, quad150)
  return carry, y

_, y = jax.lax.scan(body_fun, (), kvec)  # don't need to scan over the constants
y
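
A self-contained sketch of the same pattern, with the constants closed over and an empty carry, might look like the following. The three-node Simpson rule, the quad_integral helper, and the toy integrand are assumptions made for illustration and are not the quadIntegral or integrant from the question. The helper is also written to be pure (it returns rescaled nodes rather than storing them), which goes a little further than the fix above.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical fixed quadrature rule on [0, 1] (Simpson's rule, 3 nodes).
NODES = jnp.array([0.0, 0.5, 1.0])
WEIGHTS = jnp.array([1.0, 4.0, 1.0]) / 6.0


def quad_integral(f, a, b):
    # Pure helper: maps the rule to [a, b] without modifying NODES or WEIGHTS.
    x = a + (b - a) * NODES
    return (b - a) * jnp.sum(WEIGHTS * f(x))


def integrand(r, k):
    # Toy integrand; its integral over [0, 2] is 8 * k / 3.
    return k * r**2


def body_fun(carry, k):
    # Constants come from the enclosing scope; the carry is passed through unchanged.
    y = quad_integral(partial(integrand, k=k), 0.0, 2.0)
    return carry, y


kvec = jnp.array([1.0, 2.0, 3.0])
_, y_scan = jax.lax.scan(body_fun, (), kvec)

# Reference: the same integral evaluated with a plain Python loop.
y_loop = jnp.stack([quad_integral(partial(integrand, k=k), 0.0, 2.0) for k in kvec])
```

Because nothing is written back between steps, y_scan and y_loop agree, and each entry is 8 * k / 3 up to floating-point rounding.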

Answer selected by jecampagne