Hello, everyone!
This doesn't seem like a JAX-specific question, and it's hard to give good advice without knowing more about your problem. In general, applying constraints to the outputs of a neural net is an active area of research. Your best bet is probably an idea called differentiable optimization, in which the last layer of your neural net is itself an optimization problem with constraints. Implicit differentiation lets you differentiate through the solver without unrolling its loop (which is very memory-expensive). JAXopt might be a good place to start. Depending on the structure of your constraints, you might simply be able to project onto the feasible set and differentiate through that projection step.
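As a minimal sketch of the projection idea (not a full differentiable-optimization layer): suppose the constraint is just that the outputs lie in a box, say `[0, 1]`. Projection onto a box is clipping, and `jnp.clip` is differentiable almost everywhere, so you can append it after the final layer and let `jax.grad` flow through it. The tiny linear "network" and the box bounds here are made-up placeholders for illustration.

```python
import jax
import jax.numpy as jnp

def project_to_box(y, lo=0.0, hi=1.0):
    # Euclidean projection onto the box [lo, hi]^n is elementwise clipping.
    # Its gradient is 1 inside the box and 0 outside, so autodiff handles it.
    return jnp.clip(y, lo, hi)

def loss(params, x, target):
    # Stand-in single linear layer; replace with your real network.
    y = x @ params["w"] + params["b"]
    y = project_to_box(y)  # constrain the outputs to the feasible set
    return jnp.mean((y - target) ** 2)

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.array([[0.1, -0.2, 0.3]])
target = jnp.array([[0.5, 0.5]])

# jax.grad differentiates through the projection step automatically.
grads = jax.grad(loss)(params, x, target)
```

For more structured feasible sets (e.g. the probability simplex, norm balls, or affine constraints), JAXopt ships ready-made projections in its `jaxopt.projection` module; the same pattern applies, you just swap in the appropriate projection.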