BUG #394 fix prox and regularization #544

Merged · 31 commits · Sep 10, 2024
Commits (31 total; the diff below shows changes from 20 commits)
298807e
bug(regularization): fix calculation and confusion regarding threshold
himkwtn Aug 14, 2024
a2df39a
bug: fix calling regularization with wrong argument
himkwtn Aug 14, 2024
b436eab
CLN: linting
himkwtn Aug 14, 2024
070d0c0
ENH: unit tests for get_regularization
himkwtn Aug 14, 2024
58a0ce2
ENH: unit tests for get_prox
himkwtn Aug 16, 2024
14b564a
ENH: add typings for regularization and prox
himkwtn Aug 19, 2024
c2bb3ee
bug: fix duplicate test name
himkwtn Aug 19, 2024
3cad8f6
CLN: remove debug variable
himkwtn Aug 19, 2024
4509ea2
BUG: fix l0 reg
himkwtn Aug 20, 2024
b35f24c
ENH: improve get_regularization test cases
himkwtn Aug 20, 2024
86fef20
BUG: fix cvxpy regularization calculation
himkwtn Aug 20, 2024
9ff815a
ENH: improve get prox/reg
himkwtn Aug 23, 2024
89ead0f
ENH: remove cad
himkwtn Aug 23, 2024
658c300
ENH: create unit test for get_regularization shape validation
himkwtn Aug 23, 2024
1387b81
CLN: clean up util code
himkwtn Aug 26, 2024
ded5b29
revert test cases and add thresholds transpose
himkwtn Aug 26, 2024
d971225
DOC: constrained sr3 method update doc string
himkwtn Aug 26, 2024
de58118
CLN: fix linting
himkwtn Aug 26, 2024
ba3798e
CLN: fix linting
himkwtn Aug 26, 2024
6928455
CLN: fix linting
himkwtn Aug 26, 2024
05ee9e6
change shape validation
himkwtn Aug 28, 2024
2981f77
CLN: merge weighted and non-weighted prox/reg fn
himkwtn Aug 28, 2024
c2237bf
clean up docstring
himkwtn Aug 28, 2024
5ee3bcc
CLN: fix constrained SR3 docstring
himkwtn Aug 30, 2024
870525d
BUG: fix example for using thresholds in SR3
himkwtn Aug 30, 2024
a0475aa
fix according to comments
himkwtn Sep 3, 2024
8a2afeb
fix lint
himkwtn Sep 5, 2024
a513d76
test weighted_prox
himkwtn Sep 5, 2024
98df181
publish notebook
himkwtn Sep 5, 2024
90967e0
manually fix notebook
himkwtn Sep 9, 2024
01ce00d
refactor ConstrainedSR3._calculate_penalty
himkwtn Sep 9, 2024
99 changes: 33 additions & 66 deletions pysindy/optimizers/constrained_sr3.py
@@ -14,7 +14,6 @@
from scipy.linalg import cho_factor
from sklearn.exceptions import ConvergenceWarning

from ..utils import get_regularization
from ..utils import reorder_constraints
from .sr3 import SR3

@@ -65,9 +64,9 @@ class ConstrainedSR3(SR3):

thresholder : string, optional (default 'l0')
Regularization function to use. Currently implemented options
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation), 'weighted_l0' (weighted l0 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l0' (weighted l0 norm), 'weighted_l1' (weighted l1 norm),
and 'weighted_l2' (weighted l2 norm).

max_iter : int, optional (default 30)
Maximum iterations of the optimization algorithm.
@@ -192,7 +191,6 @@ def __init__(
)

self.verbose_cvxpy = verbose_cvxpy
self.reg = get_regularization(thresholder)
self.constraint_lhs = constraint_lhs
self.constraint_rhs = constraint_rhs
self.constraint_order = constraint_order
@@ -271,20 +269,41 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
rhs = rhs.reshape(g.shape)
return inv1.dot(rhs)

@staticmethod
def _calculate_penalty(
regularization: str, regularization_weight, xi: cp.Variable
) -> cp.Expression:
"""
Args:
-----
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' |
'l2' | 'weighted_l2'
regularization_weight: float | np.array, can be a scalar
or an array of shape (n_targets, n_features)
xi: cp.Variable

Returns:
--------
cp.Expression
"""
regularization = regularization.lower()
if regularization == "l1":
return regularization_weight * cp.sum(cp.abs(xi))
elif regularization == "weighted_l1":
return cp.sum(cp.multiply(np.ravel(regularization_weight), cp.abs(xi)))
elif regularization == "l2":
return regularization_weight * cp.sum(xi**2)
elif regularization == "weighted_l2":
return cp.sum(cp.multiply(np.ravel(regularization_weight), xi**2))

def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
) -> Tuple[cp.Variable, cp.Expression]:
xi = cp.Variable(var_len)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
return xi, cost
threshold = self.thresholds if self.thresholds is not None else self.threshold
penalty = self._calculate_penalty(self.thresholder, threshold, xi)
return xi, cost + penalty

def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
if self.use_constraints:
@@ -342,58 +361,6 @@ def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
coef_new = (xi.value).reshape(coef_prev.shape)
return coef_new

def _update_sparse_coef(self, coef_full):
"""Update the regularized weight vector"""
if self.thresholds is None:
return super(ConstrainedSR3, self)._update_sparse_coef(coef_full)
else:
coef_sparse = self.prox(coef_full, self.thresholds.T)
self.history_.append(coef_sparse.T)
return coef_sparse

def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
"""Objective function"""
if q != 0:
print_ind = q % (self.max_iter // 10.0)
else:
print_ind = q
R2 = (y - np.dot(x, coef_full)) ** 2
D2 = (coef_full - coef_sparse) ** 2
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)

if self.thresholds is None:
regularization = self.reg(coef_full, self.threshold**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
else:
regularization = self.reg(coef_full, self.thresholds**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _reduce(self, x, y):
"""
Perform at most ``self.max_iter`` iterations of the SR3 algorithm
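For reference, the penalty construction that `_calculate_penalty` centralizes can be exercised on its own. The sketch below is not a call into pysindy; it restates the static method as a standalone function and wraps it in a small illustrative cvxpy least-squares problem, with the `regularization_weight` shapes following the docstring above (a scalar for 'l1'/'l2', an (n_targets, n_features) array for the weighted variants).

    import cvxpy as cp
    import numpy as np

    def calculate_penalty(regularization, regularization_weight, xi):
        # Mirrors ConstrainedSR3._calculate_penalty from this PR.
        regularization = regularization.lower()
        if regularization == "l1":
            return regularization_weight * cp.sum(cp.abs(xi))
        elif regularization == "weighted_l1":
            return cp.sum(cp.multiply(np.ravel(regularization_weight), cp.abs(xi)))
        elif regularization == "l2":
            return regularization_weight * cp.sum(xi**2)
        elif regularization == "weighted_l2":
            return cp.sum(cp.multiply(np.ravel(regularization_weight), xi**2))
        raise ValueError(f"unsupported regularization: {regularization}")

    # Illustrative fit: a weighted l1 penalty drives the heavily weighted
    # coefficient toward zero while leaving the others nearly unpenalized.
    rng = np.random.default_rng(0)
    x = rng.standard_normal((50, 3))
    y = x @ np.array([1.0, 0.0, -2.0]) + 0.01 * rng.standard_normal(50)

    xi = cp.Variable(3)
    weights = np.array([0.1, 10.0, 0.1])  # one weight per coefficient (here n_targets = 1)
    cost = cp.sum_squares(x @ xi - y) + calculate_penalty("weighted_l1", weights, xi)
    cp.Problem(cp.Minimize(cost)).solve()
    print(np.round(xi.value, 3))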
60 changes: 22 additions & 38 deletions pysindy/optimizers/sr3.py
@@ -52,8 +52,8 @@ class SR3(BaseOptimizer):

thresholder : string, optional (default 'L0')
Regularization function to use. Currently implemented options
are 'L0' (L0 norm), 'L1' (L1 norm), 'L2' (L2 norm) and 'CAD' (clipped
absolute deviation). Note by 'L2 norm' we really mean
are 'L0' (L0 norm), 'L1' (L1 norm) and 'L2' (L2 norm).
Note by 'L2 norm' we really mean
the squared L2 norm, i.e. ridge regression

trimming_fraction : float, optional (default 0.0)
@@ -179,10 +179,9 @@ def __init__(
"weighted_l0",
"weighted_l1",
"weighted_l2",
"cad",
):
raise NotImplementedError(
"Please use a valid thresholder, l0, l1, l2, cad, "
"Please use a valid thresholder, l0, l1, l2, "
"weighted_l0, weighted_l1, weighted_l2."
)
if thresholder[:8].lower() == "weighted" and thresholds is None:
@@ -212,6 +211,10 @@ def __init__(
self.threshold = threshold
self.thresholds = thresholds
self.nu = nu
if thresholds is not None:
self.lam = thresholds.T**2 / (2 * nu)
else:
self.lam = threshold**2 / (2 * nu)
self.tol = tol
self.thresholder = thresholder
self.reg = get_regularization(thresholder)
@@ -253,36 +256,20 @@ def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)
if self.thresholds is None:
regularization = self.reg(coef_full, self.threshold**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
else:
regularization = self.reg(coef_full, self.thresholds**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
regularization = self.reg(coef_full, self.lam)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _update_full_coef(self, cho, x_transpose_y, coef_sparse):
"""Update the unregularized weight vector"""
@@ -293,10 +280,7 @@ def _update_full_coef(self, cho, x_transpose_y, coef_sparse):

def _update_sparse_coef(self, coef_full):
"""Update the regularized weight vector"""
if self.thresholds is None:
coef_sparse = self.prox(coef_full, self.threshold)
else:
coef_sparse = self.prox(coef_full, self.thresholds.T)
coef_sparse = self.prox(coef_full, self.lam * self.nu)
return coef_sparse

def _update_trimming_array(self, coef_full, trimming_array, trimming_grad):
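The `self.lam` attribute introduced here ties the user-facing `threshold` to the weight actually passed to the regularization and proximal operator: `lam = threshold**2 / (2 * nu)` enters the printed objective, while `lam * nu = threshold**2 / 2` is handed to the prox. Under the standard hard-thresholding form of the l0 proximal operator (an assumption about the helper in `pysindy.utils`, not a quote of it), that parameterization makes the effective cutoff on |coef| exactly `threshold`:

    import numpy as np

    def prox_l0(x, regularization_weight):
        # Standard hard-thresholding operator: prox_{w * ||.||_0}(x)
        # zeroes entries with |x_i| <= sqrt(2 * w).
        return x * (np.abs(x) > np.sqrt(2 * regularization_weight))

    threshold, nu = 0.5, 2.0
    lam = threshold**2 / (2 * nu)            # weight used in the SR3 objective
    coef_full = np.array([-0.8, 0.3, 0.49, 0.51, 1.2])

    # _update_sparse_coef passes lam * nu = threshold**2 / 2 to the prox,
    # so the cutoff is sqrt(2 * lam * nu) = threshold = 0.5.
    coef_sparse = prox_l0(coef_full, lam * nu)
    print(coef_sparse)                       # 0.3 and 0.49 are zeroed; 0.51 survives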
36 changes: 1 addition & 35 deletions pysindy/optimizers/stable_linear_sr3.py
@@ -62,8 +62,7 @@ class StableLinearSR3(ConstrainedSR3):

thresholder : string, optional (default 'l1')
Regularization function to use. Currently implemented options
are 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation),
are 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
Note that the thresholder must be convex here.

@@ -310,39 +309,6 @@ def _update_A(self, A_old, coef_sparse):
A_temp[r:, :r] = A_old[r:, :r]
return A_temp.T

def _objective(
self, x, y, q, coef_negative_definite, coef_sparse, trimming_array=None
):
"""Objective function"""
if q != 0:
print_ind = q % (self.max_iter // 10.0)
else:
print_ind = q
R2 = (y - np.dot(x, coef_negative_definite)) ** 2
D2 = (coef_negative_definite - coef_sparse) ** 2
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)

regularization = self.reg(
coef_negative_definite,
(self.threshold**2 if self.thresholds is None else self.thresholds**2)
/ self.nu,
)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _reduce(self, x, y):
"""
Perform at most ``self.max_iter`` iterations of the SR3 algorithm
14 changes: 0 additions & 14 deletions pysindy/utils/__init__.py
@@ -10,13 +10,6 @@
from .base import get_prox
from .base import get_regularization
from .base import print_model
from .base import prox_cad
from .base import prox_l0
from .base import prox_l1
from .base import prox_l2
from .base import prox_weighted_l0
from .base import prox_weighted_l1
from .base import prox_weighted_l2
from .base import reorder_constraints
from .base import supports_multiple_targets
from .base import validate_control_variables
@@ -66,13 +59,6 @@
"get_prox",
"get_regularization",
"print_model",
"prox_cad",
"prox_l0",
"prox_weighted_l0",
"prox_l1",
"prox_weighted_l1",
"prox_l2",
"prox_weighted_l2",
"reorder_constraints",
"supports_multiple_targets",
"validate_control_variables",
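Since the individual `prox_*` helpers are no longer re-exported, `get_prox` and `get_regularization` remain the public entry points in `pysindy.utils`. The sketch below is a usage sketch only: the call signatures (a coefficient array plus a scalar or per-coefficient weight array) and the returned callables are inferred from how the optimizers in this PR use `self.prox` and `self.reg`, so treat the shapes and semantics as assumptions rather than documented API.

    import numpy as np
    from pysindy.utils import get_prox, get_regularization

    coef = np.array([[0.05, -1.3, 0.7]])       # assumed shape (n_targets, n_features)
    weights = np.array([[0.9, 0.1, 0.3]])      # per-coefficient regularization weights

    prox = get_prox("weighted_l1")             # assumed to return a callable prox(x, weight)
    reg = get_regularization("weighted_l1")    # assumed to return a callable reg(x, weight)

    coef_sparse = prox(coef, weights)          # entrywise weighted l1 proximal step
    penalty = reg(coef, weights)               # scalar value of the weighted l1 penalty
    print(coef_sparse, penalty)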