```
JaxStackTraceBeforeTransformation: TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-24-f2a2a3471989> in <cell line: 28>()
26
27 print(test_loss(jnp.array([-0.5]))) # -1.0625
---> 28 print(jax.grad(test_loss)(jnp.array([1.]))) # error
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in solver_fun_bwd(tup, cotangent)
234
235 # Compute VJPs w.r.t. args.
--> 236 vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
237 args=ba_args[1:], cotangent=cotangent, solve=solve)
238 # Prepend None as the vjp for init_params.
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in root_vjp(optimality_fun, sol, args, cotangent, solve)
58 return optimality_fun(sol, *args)
59
---> 60 _, vjp_fun_sol = jax.vjp(fun_sol, sol)
61
62 # Compute the multiplication A^T u = (u^T A)^T.
[... skipping hidden 7 frame]
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in fun_sol(sol)
56 def fun_sol(sol):
57 # We close over the arguments.
---> 58 return optimality_fun(sol, *args)
59
60 _, vjp_fun_sol = jax.vjp(fun_sol, sol)
/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in optimality_fun(params, params_obj, params_eq, params_ineq)
352 primal_var, eq_dual_var, ineq_dual_var = params
353
--> 354 stationarity = grad_fun(primal_var, params_obj)
355
356 if eq_dual_var is not None:
[... skipping hidden 10 frame]
<ipython-input-24-f2a2a3471989> in <lambda>(z, params_obj)
1 import jaxopt
2
----> 3 fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
4 matvec_A = lambda params_A, z: (z, )
5 solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)
TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
```
Do `optimality_fun`/`grad_fun` not adapt the original `fun` so that it handles tangents properly?
I can successfully get the gradient using the `(Q, c)` and `matvec_Q` paths.
I can write my actual function using either of these, but I imagine that might be difficult for other operations, which I assume is the reason `fun` is supported in the first place.
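The last line of the traceback is a plain Python error: the `@` operator is not defined for `tuple` operands, so during the backward pass at least one `@` in `fun` received tuples rather than arrays. As a hedged workaround sketch (not a confirmed fix, and not part of jaxopt's documented API), the objective can be written with pytree-aware dot products so it accepts either arrays or tuples of arrays. `tree_vdot` is a hypothetical helper built on `jax.tree_util`; whether this resolves the `BoxOSQP` case depends on what structure jaxopt actually passes to `fun`, which this report does not establish.

```python
import jax
import jax.numpy as jnp


def tree_vdot(a, b):
    """Hypothetical helper: sum of jnp.vdot over matching pytree leaves.

    a and b must share the same pytree structure; a plain array is a
    single-leaf pytree, so this works for arrays and tuples alike.
    """
    per_leaf = jax.tree_util.tree_map(jnp.vdot, a, b)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def fun(z, params_obj):
    # Same objective as the report's lambda, 0.5 * <z, z> + <c, z> - 1,
    # with `@` replaced by the pytree-aware tree_vdot (an assumption).
    c = params_obj[0]
    return 0.5 * tree_vdot(z, z) + tree_vdot(c, z) - 1.0


# Array inputs: value and gradient (d/dz = z + c).
z = jnp.array([1.0, 2.0])
c = jnp.array([1.0, 1.0])
print(fun(z, (c,)))            # 4.5
print(jax.grad(fun)(z, (c,)))  # [2. 3.]

# Tuple (pytree) inputs, which `@` would reject, give the same value.
print(fun((z,), ((c,),)))      # 4.5
```

Written this way, the objective is the same quadratic as the `(Q, c)` form with `Q` the identity and `c = params_obj[0]` (plus the constant `-1`), which may make it easier to compare against the paths that already work.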