jax.lax.scan strange results compared to other methdos #14348
-
Hello, Here a snippet that show a strange behaviour of jax.lax.scan may be due to wrong PyTree structure.
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial
from jax import grad, jit, vmap
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class ClenshawCurtisQuad:
"""
Clenshaw-Curtis quadrature computed by FFT
"""
def __init__(self,order, absc=None, absw=None, errw=None):
# 2n-1 quad
self._order = jnp.int64(2*order-1)
if absc==None or absw==None or errw==None:
self._absc, self._absw, self._errw = self.ComputeAbsWeights()
else:
self._absc, self._absw, self._errw = absc, absw, errw
self.rescaleAbsWeights()
def __str__(self):
return f"xi={self._absc}\n wi={self._absw}\n errwi={self._errw}"
@property
def absc(self):
return self._absc
@property
def absw(self):
return self._absw
@property
def errw(self):
return self._errw
def ComputeAbsWeights(self):
x,wx = self.absweights(self._order)
nsub = (self._order+1)//2
xSub, wSub = self.absweights(nsub)
errw = jnp.array(wx, copy=True) # np.copy(wx)
errw=errw.at[::2].add(-wSub) # errw[::2] -= wSub
return x,wx,errw
def absweights(self,n):
points = -jnp.cos((jnp.pi * jnp.arange(n)) / (n - 1))
if n == 2:
weights = jnp.array([1.0, 1.0])
return points, weights
n -= 1
N = jnp.arange(1, n, 2)
length = len(N)
m = n - length
v0 = jnp.concatenate([2.0 / N / (N - 2), jnp.array([1.0 / N[-1]]), jnp.zeros(m)])
v2 = -v0[:-1] - v0[:0:-1]
g0 = -jnp.ones(n)
g0 = g0.at[length].add(n) # g0[length] += n
g0 = g0.at[m].add(n) # g0[m] += n
g = g0 / (n ** 2 - 1 + (n % 2))
w = jnp.fft.ihfft(v2 + g)
###assert max(w.imag) < 1.0e-15
w = w.real
if n % 2 == 1:
weights = jnp.concatenate([w, w[::-1]])
else:
weights = jnp.concatenate([w, w[len(w) - 2 :: -1]])
#return
return points, weights
def rescaleAbsWeights(self, xInmin=-1.0, xInmax=1.0, xOutmin=0.0, xOutmax=1.0):
"""
Translate nodes,weights for [xInmin,xInmax] integral to [xOutmin,xOutmax]
"""
deltaXIn = xInmax-xInmin
deltaXOut= xOutmax-xOutmin
scale = deltaXOut/deltaXIn
self._absw *= scale
tmp = jnp.array([((xi-xInmin)*xOutmax
-(xi-xInmax)*xOutmin)/deltaXIn for xi in self._absc])
self._absc=tmp
#### PyTree methods....
def tree_flatten(self):
children = (self._order, self._absc, self._absw, self._errw)
aux_data = {}
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
order, absc, absw, errw = children
return cls(order=order, absc=absc, absw=absw, errw=errw) # doit respected la signature de __init__
@partial(jit, static_argnums=(0,3))
def quadIntegral(f,a,b,quad):
a = jnp.atleast_1d(a)
b = jnp.atleast_1d(b)
d = b-a
xi = a[jnp.newaxis,:]+ jnp.einsum('i...,k...->ik...',quad.absc,d)
xi = xi.squeeze()
fi = f(xi)
S = d * jnp.einsum('i...,i...',quad.absw,fi)
return S.squeeze()#, d*jnp.einsum('i...,i...',quad.errw,fi) Here a function to integrate along def integrant(x, k, beta):
return x * jnp.power(1+x**2,-beta) * jnp.exp(-k*x) Some parameters rmax=10.
beta=2.
quad150=ClenshawCurtisQuad(150)
kvec = jnp.linspace(20.,30.,10) Here a simple for loop along for k in kvec:
print(k, quadIntegral(partial(integrant, k=k, beta=beta),0., rmax, quad150)) these gives integrale values that are in agreement with Mathematica
Now, the for loop is not what a user of JAX wants at the end of the day is'nt it. Here are some partial spectialisation to fix some parameters and get a 'x' only function def fCC(k,beta,rmax, quad):
return quadIntegral(partial(integrant, k=k, beta=beta),0., rmax, quad)
fCCbis = partial(fCC,beta=beta,rmax=rmax, quad=quad150) Using jax.vmap(fCCbis)(kvec) gives correct results
furthermore I can jit it jax.jit(jax.vmap(fCCbis))(kvec) which gives the same correct results. Let now try it with def body_fun(carry, k):
rmax, beta, quad = carry
y = quadIntegral(partial(integrant, k=k, beta=beta),0., rmax, quad)
return carry, y
_, y = jax.lax.scan(body_fun, (rmax,beta, quad150), kvec)
y This leads to completly wrong values
One can mimic the def scan_no_jit(f, init, xs):
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, jnp.stack(ys)
@partial(jit, static_argnums=0)
def scan_jit(f, init, xs):
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, jnp.stack(ys) The non-jitted version leads to correct results _, y = scan_no_jit(body_fun, (rmax,beta,quad150), kvec)
#DeviceArray([0.00242997, 0.00218699, 0.00197849, 0.00179828, 0.00164149,
# 0.00150425, 0.00138345, 0.00127658, 0.00118158, 0.00109676], dtype=float64) BUT the jitted one gives wrong results AND different of the scan version _, y = scan_jit(body_fun, (rmax,beta,quad150), kvec)
#DeviceArray([1.33801252e-47, 4.90743720e-50, 1.80464991e-52,
# 6.65220730e-55, 2.45741798e-57, 9.09600903e-60,
# 3.37294859e-62, 1.25283115e-64, 4.66061921e-67,
# 1.73625466e-69], dtype=float64) HOWEVER this is not the end, I can use another simpler integrator which is not a PyTree def simps(f, a, b, N=128):
"""Approximate the integral of f(x) from a to b by Simpson's rule.
"""
if N % 2 == 1:
raise ValueError("N must be an even integer.")
dx = (b - a) / N
x = jnp.linspace(a, b, N + 1)
y = f(x)
S = dx / 3 * jnp.sum(y[0:-1:2] + 4 * y[1::2] + y[2::2], axis=0)
return S and repeat the game and the TWO manual_scan AND the jax.lax. scan gives correct results ! def body_fun2(carry, k):
rmax, beta,_ = carry
y = simps(partial(integrant, k=k, beta=beta),0., rmax, N=256)
return carry, y
_, y = scan_no_jit(body_fun2, (rmax,beta, quad150), kvec) # OK
_, y = scan_jit(body_fun2, (rmax,beta, quad150), kvec) #OK Too
_, y = jax.lax.scan(body_fun2, (rmax,beta, quad150), kvec) #OK Too So ----> I imagine an interaction of Jit/PyTree in the ClenshawCurtisQuad even if the jitted vmap is possibly working. This is not confortable and I would like a new eyes on the above snippet to try to understand the failure... Thanks |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 22 replies
-
Ok, I see the problem. class ClenshawCurtisQuad:
...
def rescaleAbsWeights(...):
jax.debug.print('called')
...
...
def body_fun(carry, k):
...
jax.debug.print('{}', quad.absc.sum())
...
_, y = jax.lax.scan(body_fun, (rmax, beta, quad150), kvec)
y The fix is easy, you can just not scan over def body_fun(carry, k):
y = quadIntegral(partial(integrant, k=k, beta=beta), 0., rmax, quad150)
return carry, y
_, y = jax.lax.scan(body_fun, (), kvec) # don't need to scan over the constants
y |
Beta Was this translation helpful? Give feedback.
-
hello @soraros You mention that |
Beta Was this translation helpful? Give feedback.
Ok, I see the problem.
Your
rescaleAbsWeights
is a stateful method and it modifiesself
on every iteration.This can be observed if you run the following
The fix is easy, you can just not scan over
quad
, like