Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG #394 fix prox and regularization #544

Merged
merged 31 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
298807e
bug(regularization): fix calculation and confusion regarding threshold
himkwtn Aug 14, 2024
a2df39a
bug: fix calling regularization with wrong argument
himkwtn Aug 14, 2024
b436eab
CLN: linting
himkwtn Aug 14, 2024
070d0c0
ENH: unit tests for get_regularization
himkwtn Aug 14, 2024
58a0ce2
ENH: unit tests for get_prox
himkwtn Aug 16, 2024
14b564a
ENH: add typings for regularization and prox
himkwtn Aug 19, 2024
c2bb3ee
bug: fix duplicate test name
himkwtn Aug 19, 2024
3cad8f6
CLN: remove debug variable
himkwtn Aug 19, 2024
4509ea2
BUG: fix l0 reg
himkwtn Aug 20, 2024
b35f24c
ENH: improve get_regularization test cases
himkwtn Aug 20, 2024
86fef20
BUG: fix cvxpy regularization calculation
himkwtn Aug 20, 2024
9ff815a
ENH: improve get prox/reg
himkwtn Aug 23, 2024
89ead0f
ENH: remove cad
himkwtn Aug 23, 2024
658c300
ENH: create unit test for get_regularization shape validation
himkwtn Aug 23, 2024
1387b81
CLN: clean up util code
himkwtn Aug 26, 2024
ded5b29
revert test cases and add thresholds transpose
himkwtn Aug 26, 2024
d971225
DOC: constrained sr3 method update doc string
himkwtn Aug 26, 2024
de58118
CLN: fix linting
himkwtn Aug 26, 2024
ba3798e
CLN: fix linting
himkwtn Aug 26, 2024
6928455
CLN: fix linting
himkwtn Aug 26, 2024
05ee9e6
change shape validation
himkwtn Aug 28, 2024
2981f77
CLN: merge weighted and non-weighted prox/reg fn
himkwtn Aug 28, 2024
c2237bf
clean up docstring
himkwtn Aug 28, 2024
5ee3bcc
CLN: fix constrained SR3 docstring
himkwtn Aug 30, 2024
870525d
BUG: fix example for using thresholds in SR3
himkwtn Aug 30, 2024
a0475aa
fix according to comments
himkwtn Sep 3, 2024
8a2afeb
fix lint
himkwtn Sep 5, 2024
a513d76
test weighted_prox
himkwtn Sep 5, 2024
98df181
publish notebook
himkwtn Sep 5, 2024
90967e0
manually fix notebook
himkwtn Sep 9, 2024
01ce00d
refactor ConstrainedSR3._calculate_penalty
himkwtn Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ class ConstrainedSR3(SR3):

thresholder : string, optional (default 'l0')
Regularization function to use. Currently implemented options
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation), 'weighted_l0' (weighted l0 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l0' (weighted l0 norm), 'weighted_l1' (weighted l1 norm),
and 'weighted_l2' (weighted l2 norm).

max_iter : int, optional (default 30)
Maximum iterations of the optimization algorithm.
Expand Down Expand Up @@ -270,16 +270,31 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
return inv1.dot(rhs)

@staticmethod
def _calculate_penalty(regularizer, lam, xi: cp.Variable) -> cp.Expression:
regularizer = regularizer.lower()
if regularizer == "l1":
return lam * cp.sum(cp.abs(xi))
elif regularizer == "weighted_l1":
return cp.sum(cp.multiply(np.ravel(lam), cp.abs(xi)))
elif regularizer == "l2":
return lam * cp.sum(xi**2)
elif regularizer == "weighted_l2":
return cp.sum(cp.multiply(np.ravel(lam), xi**2))
def _calculate_penalty(
regularization: str, regularization_weight, xi: cp.Variable
) -> cp.Expression:
"""
Args:
-----
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' |
'l2' | 'weighted_l2'
regularization_weight: float | np.array, can be a scalar
or an array of shape (n_targets, n_features)
Jacob-Stevens-Haas marked this conversation as resolved.
Show resolved Hide resolved
xi: cp.Variable

Returns:
--------
cp.Expression
"""
regularization = regularization.lower()
if regularization == "l1":
return regularization_weight * cp.sum(cp.abs(xi))
elif regularization == "weighted_l1":
return cp.sum(cp.multiply(np.ravel(regularization_weight), cp.abs(xi)))
elif regularization == "l2":
return regularization_weight * cp.sum(xi**2)
elif regularization == "weighted_l2":
return cp.sum(cp.multiply(np.ravel(regularization_weight), xi**2))

def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
Expand Down
9 changes: 4 additions & 5 deletions pysindy/optimizers/sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class SR3(BaseOptimizer):

thresholder : string, optional (default 'L0')
Regularization function to use. Currently implemented options
are 'L0' (L0 norm), 'L1' (L1 norm), 'L2' (L2 norm) and 'CAD' (clipped
absolute deviation). Note by 'L2 norm' we really mean
are 'L0' (L0 norm), 'L1' (L1 norm) and 'L2' (L2 norm).
Note by 'L2 norm' we really mean
the squared L2 norm, i.e. ridge regression

trimming_fraction : float, optional (default 0.0)
Expand Down Expand Up @@ -179,10 +179,9 @@ def __init__(
"weighted_l0",
"weighted_l1",
"weighted_l2",
"cad",
):
raise NotImplementedError(
"Please use a valid thresholder, l0, l1, l2, cad, "
"Please use a valid thresholder, l0, l1, l2, "
"weighted_l0, weighted_l1, weighted_l2."
)
if thresholder[:8].lower() == "weighted" and thresholds is None:
Expand Down Expand Up @@ -213,7 +212,7 @@ def __init__(
self.thresholds = thresholds
self.nu = nu
if thresholds is not None:
self.lam = thresholds**2 / (2 * nu)
self.lam = thresholds.T**2 / (2 * nu)
else:
self.lam = threshold**2 / (2 * nu)
self.tol = tol
Expand Down
3 changes: 1 addition & 2 deletions pysindy/optimizers/stable_linear_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class StableLinearSR3(ConstrainedSR3):

thresholder : string, optional (default 'l1')
Regularization function to use. Currently implemented options
are 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation),
are 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
Note that the thresholder must be convex here.

Expand Down
14 changes: 0 additions & 14 deletions pysindy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from .base import get_prox
from .base import get_regularization
from .base import print_model
from .base import prox_cad
from .base import prox_l0
from .base import prox_l1
from .base import prox_l2
from .base import prox_weighted_l0
from .base import prox_weighted_l1
from .base import prox_weighted_l2
from .base import reorder_constraints
from .base import supports_multiple_targets
from .base import validate_control_variables
Expand Down Expand Up @@ -66,13 +59,6 @@
"get_prox",
"get_regularization",
"print_model",
"prox_cad",
"prox_l0",
"prox_weighted_l0",
"prox_l1",
"prox_weighted_l1",
"prox_l2",
"prox_weighted_l2",
"reorder_constraints",
"supports_multiple_targets",
"validate_control_variables",
Expand Down
197 changes: 107 additions & 90 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import repeat
from typing import Callable
from typing import Sequence
from typing import Union

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -152,130 +153,146 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def prox_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for L0 regularization."""
threshold = np.sqrt(2 * lam)
return x * (np.abs(x) > threshold)


def prox_weighted_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for weighted l0 regularization."""
y = np.zeros(np.shape(x))
threshold = np.sqrt(2 * lam)
for i in range(lam.shape[0]):
for j in range(lam.shape[1]):
y[i, j] = x[i, j] * (np.abs(x[i, j]) > threshold[i, j])
return y


def prox_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - lam, 0)


def prox_weighted_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - lam, np.zeros(x.shape))


def prox_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * lam)

def validate_prox_and_reg_inputs(func, regularization):
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
raise ValueError(
f"'regularization_weight' must be an array of shape {x.shape}."
)
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Invalid shape for 'regularization_weight': \
{weight_shape}. Must be the same shape as x: {x.shape}."
)
else:
if not isinstance(regularization_weight, (int, float)) and (
isinstance(regularization_weight, np.ndarray)
and regularization_weight.shape not in [(1, 1), (1,)]
):
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("'regularization_weight' must be a scalar")
return func(x, regularization_weight)

def prox_weighted_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * lam)
return wrapper


# TODO: replace code block with proper math block
def prox_cad(x: NDArray[np.float64], lam: NDArray[np.float64]):
def get_prox(
regularization: str,
) -> Callable[
[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], NDArray[np.float64]
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
]:
"""
Proximal operator for CAD regularization

.. code ::

prox_cad(z, a, b) =
0 if |z| < a
sign(z)(|z| - a) if a < |z| <= b
z if |z| > b

Entries of :math:`x` smaller than a in magnitude are set to 0,
entries with magnitudes larger than b are untouched,
and entries in between have soft-thresholding applied.
Args:
-----
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | 'l2' | 'weighted_l2'

For simplicity we set :math:`b = 5*a` in this implementation.
Returns:
--------
proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
and returns an array of shape (n_targets, n_features)
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
"""
lower_threshold = lam
upper_threshold = 5 * lam
return prox_l0(x, upper_threshold) + prox_l1(x, lower_threshold) * (
np.abs(x) < upper_threshold
)

def prox_l0(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L0 regularization."""
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for weighted l0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)
himkwtn marked this conversation as resolved.
Show resolved Hide resolved

def prox_l1(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_l2(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

def prox_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

def get_prox(
regularization: str,
) -> Callable[[NDArray[np.float64], NDArray[np.float64]], NDArray[np.float64]]:
prox = {
"l0": prox_l0,
"weighted_l0": prox_weighted_l0,
"l1": prox_l1,
"weighted_l1": prox_weighted_l1,
"l2": prox_l2,
"weighted_l2": prox_weighted_l2,
"cad": prox_cad,
}
if regularization.lower() in prox:
return prox[regularization.lower()]
else:
raise NotImplementedError("{} has not been implemented".format(regularization))


def regularization_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
return lam * np.count_nonzero(x)

regularization = regularization.lower()
return validate_prox_and_reg_inputs(prox[regularization], regularization)
himkwtn marked this conversation as resolved.
Show resolved Hide resolved

def regualization_weighted_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam[np.nonzero(x)])

def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], float]:
"""
Args:
-----
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | 'l2' | 'weighted_l2'

def regularization_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * np.abs(x))


def regualization_weighted_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * np.abs(x))

Returns:
--------
regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
and returns a float
"""

def regularization_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * x**2)
def regularization_l0(x: NDArray[np.float64], regularization_weight: np.float64):
return regularization_weight * np.count_nonzero(x)

def regualization_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight[np.nonzero(x)])

def regualization_weighted_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * x**2)
def regularization_l1(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * np.abs(x))

def regualization_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * np.abs(x))

def regularization_cad(x: NDArray[np.float64], lam: NDArray[np.float64]):
# dummy function
return 0
def regularization_l2(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * x**2)

def regualization_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * x**2)

def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], NDArray[np.float64]], float]:
regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regualization_weighted_l0,
"l1": regularization_l1,
"weighted_l1": regualization_weighted_l1,
"l2": regularization_l2,
"weighted_l2": regualization_weighted_l2,
"cad": regularization_cad,
}
if regularization.lower() in regularization_fn:
return regularization_fn[regularization.lower()]
else:
raise NotImplementedError("{} has not been implemented".format(regularization))
regularization = regularization.lower()
return validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)


def capped_simplex_projection(trimming_array, trimming_fraction):
Expand Down
Loading
Loading