Error when taking gradient wrt parameters in BoxOSQP #588

Open
deasmhumhna opened this issue Apr 1, 2024 · 0 comments

@deasmhumhna

```python
import jax
import jax.numpy as jnp
import jaxopt

# Quadratic objective in z: 0.5 * z @ z + c @ z - 1, with c = params_obj[0].
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
# A = identity; matvec_A returns a 1-tuple, so l and u are 1-tuples as well.
matvec_A = lambda params_A, z: (z,)
solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

def test_loss(a):
    params_obj = (jnp.atleast_1d(a),)
    l = (jnp.array([0.]),)
    u = (jnp.array([1.]),)

    init_params = solver.init_params(
        init_x=jnp.array([0.]),
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u)
    )
    sol = solver.run(
        init_params=init_params,
        params_obj=params_obj,
        params_eq=None,
        params_ineq=(l, u)
    )
    # primal is (x, z) with z = A x; z is a 1-tuple here, hence [-1][-1].
    zopt = sol.params.primal[-1][-1]
    return fun(zopt, params_obj)

print(test_loss(jnp.array([-0.5]))) # -1.125
print(jax.grad(test_loss)(jnp.array([1.]))) # error
```
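Since the error only appears under `jax.grad`, a closed-form reference for this toy problem is handy for checking what the gradient should be. With Q = I and the box [0, 1], the minimizer of `0.5 * z @ z + a @ z` is the unconstrained optimum `-a` clipped to [0, 1]. The sketch below differentiates that closed form with plain JAX autodiff through `jnp.clip`; `reference_loss` is a hypothetical helper, not part of jaxopt.

```python
import jax
import jax.numpy as jnp

# Same objective as in the report.
fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1

def reference_loss(a):
    params_obj = (jnp.atleast_1d(a),)
    # With Q = I and 0 <= z <= 1, the minimizer is -a clipped to the box.
    zopt = jnp.clip(-params_obj[0], 0.0, 1.0)
    return fun(zopt, params_obj)

print(reference_loss(jnp.array([-0.5])))
print(jax.grad(reference_loss)(jnp.array([-0.5])))
print(jax.grad(reference_loss)(jnp.array([1.])))
```

At a = -0.5 this gives a loss of -1.125 (the value in the comment above) and a gradient of 0.5. At a = 1 the lower bound z >= 0 is active, so the minimizer does not move and the gradient is 0; an implicit-differentiation gradient from `BoxOSQP` would be expected to agree with these up to the solver tolerance.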

Relevant traceback:

```
JaxStackTraceBeforeTransformation: TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-24-f2a2a3471989> in <cell line: 28>()
     26 
     27 print(test_loss(jnp.array([-0.5]))) # -1.0625
---> 28 print(jax.grad(test_loss)(jnp.array([1.]))) # error

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in solver_fun_bwd(tup, cotangent)
    234 
    235       # Compute VJPs w.r.t. args.
--> 236       vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
    237                       args=ba_args[1:], cotangent=cotangent, solve=solve)
    238       # Prepend None as the vjp for init_params.

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in root_vjp(optimality_fun, sol, args, cotangent, solve)
     58     return optimality_fun(sol, *args)
     59 
---> 60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)
     61 
     62   # Compute the multiplication A^T u = (u^T A)^T.

    [... skipping hidden 7 frame]

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in fun_sol(sol)
     56   def fun_sol(sol):
     57     # We close over the arguments.
---> 58     return optimality_fun(sol, *args)
     59 
     60   _, vjp_fun_sol = jax.vjp(fun_sol, sol)

/usr/local/lib/python3.10/dist-packages/jaxopt/_src/implicit_diff.py in optimality_fun(params, params_obj, params_eq, params_ineq)
    352     primal_var, eq_dual_var, ineq_dual_var = params
    353 
--> 354     stationarity = grad_fun(primal_var, params_obj)
    355 
    356     if eq_dual_var is not None:

    [... skipping hidden 10 frame]

<ipython-input-24-f2a2a3471989> in <lambda>(z, params_obj)
      1 import jaxopt
      2 
----> 3 fun = lambda z, params_obj: 0.5 * z @ z + params_obj[0] @ z - 1
      4 matvec_A = lambda params_A, z: (z, )
      5 solver = jaxopt.BoxOSQP(matvec_A=matvec_A, fun=fun, tol=1e-5)

TypeError: unsupported operand type(s) for @: 'tuple' and 'tuple'
```

Doesn't `optimality_fun`/`grad_fun` alter the original `fun` so that it handles tangents properly?

I can successfully get the gradient using the `(Q, c)` and `matvec_Q` paths.
I can write my actual function using either of these, but I imagine this might be difficult for other objectives, which I assume is the reason `fun` is supported.
