[BUG] Are weighted norms in SR3 classes computed incorrectly? #394

Closed
Jacob-Stevens-Haas opened this issue Aug 15, 2023 · 3 comments · Fixed by #544

@Jacob-Stevens-Haas
Collaborator

Jacob-Stevens-Haas commented Aug 15, 2023

SR3 and its subclasses use weighted norms in two types of expressions, and in both cases the classes appear to compute them incorrectly.

Norms calculated using cvxpy:

cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
# or
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2

This is odd: xi is a 1-D vector, as is np.ravel(...), so the @ operator computes their inner product, which is a scalar. Both expressions are therefore probably computing the norm of a scalar, not of a vector.
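A minimal sketch of the concern, using plain NumPy arrays in place of the cvxpy variable (the values of w and xi are illustrative, not taken from pysindy). With 1-D arrays, w @ xi collapses to a single number, so its norm is just an absolute value, while a weighted norm needs one weight per coefficient. In cvxpy, elementwise multiplication is written cp.multiply(w, xi), which is presumably what the weighted terms need.

```python
import numpy as np

w = np.array([1.0, 2.0])    # per-coefficient weights (illustrative values)
xi = np.array([1.0, -1.0])  # a 1-D coefficient vector (illustrative values)

# 1-D @ 1-D is an inner product: a 0-d scalar, not a vector.
inner = w @ xi
print(np.ndim(inner))  # 0

# What norm1(w @ xi) reduces to vs. the elementwise weighted l1 norm sum_i w_i |xi_i|.
norm1_of_inner = abs(inner)
weighted_l1 = np.sum(np.abs(w * xi))

# What norm2(w @ xi) ** 2 reduces to vs. the weighted sum of squares sum_i w_i xi_i^2.
norm2sq_of_inner = inner ** 2
weighted_l2sq = np.sum(w * xi ** 2)

print(norm1_of_inner, weighted_l1)      # 1.0 3.0
print(norm2sq_of_inner, weighted_l2sq)  # 1.0 3.0
```

The mixed signs in xi let the inner product cancel, which is exactly the behavior a per-coefficient weighted norm should not have.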

Norms calculated using pysindy.utils.base.get_regularization():
This function has no unit tests, so it is hard to tell what inputs it expects, but this line suggests that the thresholds may not be applied correctly (self.reg is the lambda function returned by get_regularization, coeff_full has shape (n_features, n_targets), and self.thresholds has shape (n_targets, n_features)). Elementwise multiplication, rather than matrix multiplication, seems to be what is intended here.
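To see why the shapes matter, here is a small NumPy sketch with arrays of the shapes described above (random, illustrative values; this is not pysindy's actual code). A matrix product of an (n_targets, n_features) array with an (n_features, n_targets) array yields an (n_targets, n_targets) array that mixes coefficients together, whereas elementwise weighting needs the transpose so each coefficient is paired with its own threshold.

```python
import numpy as np

n_features, n_targets = 3, 2
rng = np.random.default_rng(0)
coeff_full = rng.normal(size=(n_features, n_targets))             # (n_features, n_targets)
thresholds = rng.uniform(0.1, 1.0, size=(n_targets, n_features))  # (n_targets, n_features)

# Matrix multiplication sums over features, mixing different coefficients.
mixed = thresholds @ coeff_full
print(mixed.shape)  # (2, 2)

# Elementwise weighting: transpose so entry (i, j) weights coeff_full[i, j].
weighted = thresholds.T * coeff_full
print(weighted.shape)  # (3, 2)

# Weighted l1 penalty: sum over all entries of threshold * |coefficient|.
penalty_l1 = np.sum(np.abs(weighted))
print(penalty_l1)
```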

Reproducing code example:

I haven't built a small regression problem to prove these things are wrong, because I haven't worked with SR3 beyond speeding up the tests in #393. But it's easy to see that get_regularization() returns the norm of the inner product, even though the weights are (I think?) supposed to be applied elementwise:

import pysindy as ps
import numpy as np

weights = np.ones(2)
arr = np.ones(2)

reg1 = ps.utils.get_regularization("weighted_l1")
reg2 = ps.utils.get_regularization("weighted_l2")

print(reg1(arr, weights))
print(reg2(arr, weights))

Error message:

Both show that the norm is 2.
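One caveat about the reproducing example: with all-ones weights and inputs, the inner-product reading and the elementwise reading both give 2, so that input cannot tell them apart. The sketch below uses two hypothetical functions written only to mirror those two readings (they are not pysindy's code) and shows that an input with mixed signs and unequal weights separates them.

```python
import numpy as np

def inner_product_l1(x, w):
    # Hypothetical: the norm of the inner product, as described above.
    return np.abs(w @ x)

def elementwise_l1(x, w):
    # Hypothetical: elementwise weighting, sum_i w_i |x_i|.
    return np.sum(w * np.abs(x))

ones = np.ones(2)
print(inner_product_l1(ones, ones), elementwise_l1(ones, ones))  # 2.0 2.0

x = np.array([2.0, -1.0])
w = np.array([0.5, 3.0])
print(inner_product_l1(x, w), elementwise_l1(x, w))  # 2.0 4.0
```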

Thoughts

Broadly speaking, these are probably the next steps. Caveat: very little of my work involves SR3, and none of it involves SR3 weighted thresholding, so I may get through step 1 at most.

  1. Create a toy regression to determine whether this problem actually exists or I'm misreading the code.
  2. Add annotations/docstrings/tests to the parts of pysindy.utils.base that SR3 uses.
  3. Refactor the SR3 class and subclasses to make each step testable.
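For step 2, tests for a weighted regularizer could pin down the intended elementwise behavior. The sketch below uses a hypothetical reference function, weighted_l1, which assumes weights with the same shape as the coefficients; it is not pysindy's get_regularization, and the test names are illustrative. The last test is one that an inner-product implementation would fail, since |1 - 1| = 0.

```python
import numpy as np

def weighted_l1(x, w):
    """Hypothetical reference: sum_i w_i |x_i|, with w the same shape as x."""
    x, w = np.asarray(x, dtype=float), np.asarray(w, dtype=float)
    return float(np.sum(w * np.abs(x)))

def test_uniform_weights_match_scaled_l1():
    # Uniform weights lam should reduce to lam * ||x||_1.
    x = np.array([[1.0, -2.0], [0.5, 3.0]])
    lam = 0.7
    assert np.isclose(weighted_l1(x, np.full_like(x, lam)), lam * np.sum(np.abs(x)))

def test_zero_weight_ignores_entry():
    # A zero weight should leave its coefficient unpenalized.
    x = np.array([5.0, -1.0])
    assert np.isclose(weighted_l1(x, np.array([0.0, 2.0])), 2.0)

def test_no_sign_cancellation():
    # Coefficients of opposite sign must not cancel each other.
    x = np.array([1.0, -1.0])
    assert np.isclose(weighted_l1(x, np.ones(2)), 2.0)

if __name__ == "__main__":
    test_uniform_weights_match_scaled_l1()
    test_zero_weight_ignores_entry()
    test_no_sign_cancellation()
    print("all passed")
```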

@briandesilva, @akaptano, @kpchamp: I know this is old code, but if you have any time in the coming months to review it, I'd appreciate it.

Jacob-Stevens-Haas changed the title from "[BUG] Weighted norms in all SR3 classes may be wrong" to "[BUG] Are weighted norms in SR3 classes computed incorrectly?" on Aug 16, 2023
@Jacob-Stevens-Haas
Collaborator Author

@himkwtn, this is a decent issue if you want to take it on. It's a little bit mathy, so it goes beyond the good first issue tag.

himkwtn self-assigned this on Aug 5, 2024
@himkwtn
Collaborator

himkwtn commented Aug 5, 2024

I found that StableLinearSR3 might be calculating the l2 regularization term incorrectly.

cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi)

If I understand correctly, it should be self.threshold * cp.norm2(xi) ** 2, right?
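A quick numerical check of the scalar case, assuming self.threshold is a single number lam (the values below are illustrative). lam * ||xi||_2^2 is the same as weighting every squared coefficient by lam, whereas the norm of an inner product, as in the quoted line, collapses to a single absolute value.

```python
import numpy as np

lam = 0.1                        # scalar threshold (illustrative)
xi = np.array([3.0, -4.0, 0.5])  # flattened coefficients (illustrative)

# Scalar-threshold l2 penalty: lam * ||xi||_2^2.
intended = lam * np.linalg.norm(xi) ** 2

# Same penalty written elementwise with uniform weights: sum_i lam * xi_i^2.
uniform_weighted = np.sum(np.full_like(xi, lam) * xi ** 2)

# Norm of an inner product (no square), mirroring the quoted expression:
# it is just |lam * (3 - 4 + 0.5)|.
inner_norm = np.linalg.norm(np.atleast_1d(np.full_like(xi, lam) @ xi))

print(intended, uniform_weighted, inner_norm)
```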

@Jacob-Stevens-Haas
Collaborator Author

As I understand it, this line is inside the conditional for the weighted l2 norm, so it should be something like

cp.norm2(np.sqrt(self.thresholds) * np.ravel(xi))
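Why the square root appears: for nonnegative weights w_i, scaling each coefficient by sqrt(w_i) and taking the squared 2-norm recovers the weighted sum of squares. The identity holds for the squared norm; without a final ** 2 the expression gives the square root of the weighted sum, so whether to square depends on which penalty SR3 intends (the cvxpy snippet earlier in the issue uses ** 2).

```latex
\left\| \sqrt{w} \circ \xi \right\|_2^2
  = \sum_i \left( \sqrt{w_i}\, \xi_i \right)^2
  = \sum_i w_i\, \xi_i^2,
\qquad w_i \ge 0 .
```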

Jacob-Stevens-Haas linked a pull request on Sep 10, 2024 that will close this issue
Jacob-Stevens-Haas pushed a commit that referenced this issue on Sep 10, 2024
Now, prox and regularization have a plainer API: weights are either scalars or must
match the shape of the optimization variable. This removes odd use cases that only
worked because of broadcasting, which constrained our ability to simplify.

It also fixes the calculation of regularizers in SR3 _calculate_penalty, with the aim that
this will also be replaced with get_prox and get_regularization once they can
handle CVXPY (or JAX) expressions (arrays).
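A sketch of the API rule described in the commit message, using hypothetical helper names (check_weights, weighted_l1, prox_weighted_l1) that are not pysindy's functions. Weights may be a scalar or an array with exactly the variable's shape, and anything else raises an error instead of being broadcast. The prox shown is the standard elementwise soft-thresholding operator for a weighted l1 penalty, with any step-size scaling left out.

```python
import numpy as np

def check_weights(weights, x):
    """Hypothetical helper: accept a scalar or an array with exactly x's shape."""
    w = np.asarray(weights, dtype=float)
    if w.ndim != 0 and w.shape != np.shape(x):
        raise ValueError(
            f"weights shape {w.shape} does not match variable shape {np.shape(x)}"
        )
    return w

def weighted_l1(x, weights):
    """sum_ij w_ij |x_ij|; a scalar weight applies to every entry."""
    w = check_weights(weights, x)
    return float(np.sum(w * np.abs(x)))

def prox_weighted_l1(x, weights):
    """Elementwise soft thresholding, the prox of sum_ij w_ij |x_ij| (unit step)."""
    w = check_weights(weights, x)
    return np.sign(x) * np.maximum(np.abs(x) - w, 0.0)

x = np.array([[1.5, -0.2], [-3.0, 0.4]])
print(prox_weighted_l1(x, 0.5))                                   # scalar weight
print(prox_weighted_l1(x, np.array([[1.0, 0.1], [2.0, 1.0]])))  # same-shape weights
```

Rejecting mismatched shapes up front is what rules out the broadcasting-dependent cases the commit message mentions.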