Skip to content

Commit

Permalink
ENH: improve get prox/reg
Browse files Browse the repository at this point in the history
- move to private function
- rename lam to regularization_weight
- remove cad
- add docstring
- add validation
  • Loading branch information
himkwtn committed Aug 23, 2024
1 parent 86fef20 commit 9ff815a
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 108 deletions.
4 changes: 2 additions & 2 deletions pysindy/optimizers/sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class SR3(BaseOptimizer):
thresholder : string, optional (default 'L0')
Regularization function to use. Currently implemented options
are 'L0' (L0 norm), 'L1' (L1 norm), 'L2' (L2 norm) and 'CAD' (clipped
absolute deviation). Note by 'L2 norm' we really mean
are 'L0' (L0 norm), 'L1' (L1 norm) and 'L2' (L2 norm).
Note by 'L2 norm' we really mean
the squared L2 norm, i.e. ridge regression
trimming_fraction : float, optional (default 0.0)
Expand Down
3 changes: 1 addition & 2 deletions pysindy/optimizers/stable_linear_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class StableLinearSR3(ConstrainedSR3):
thresholder : string, optional (default 'l1')
Regularization function to use. Currently implemented options
are 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation),
are 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
Note that the thresholder must be convex here.
Expand Down
14 changes: 0 additions & 14 deletions pysindy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from .base import get_prox
from .base import get_regularization
from .base import print_model
from .base import prox_cad
from .base import prox_l0
from .base import prox_l1
from .base import prox_l2
from .base import prox_weighted_l0
from .base import prox_weighted_l1
from .base import prox_weighted_l2
from .base import reorder_constraints
from .base import supports_multiple_targets
from .base import validate_control_variables
Expand Down Expand Up @@ -66,13 +59,6 @@
"get_prox",
"get_regularization",
"print_model",
"prox_cad",
"prox_l0",
"prox_weighted_l0",
"prox_l1",
"prox_weighted_l1",
"prox_l2",
"prox_weighted_l2",
"reorder_constraints",
"supports_multiple_targets",
"validate_control_variables",
Expand Down
183 changes: 101 additions & 82 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import warnings
from functools import wraps
from itertools import repeat
from typing import Callable
from typing import Sequence
from typing import Union

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -152,128 +154,145 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def prox_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for L0 regularization."""
threshold = np.sqrt(2 * lam)
return x * (np.abs(x) > threshold)


def prox_weighted_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for weighted l0 regularization."""
y = np.zeros(np.shape(x))
threshold = np.sqrt(2 * lam)
for i in range(lam.shape[0]):
for j in range(lam.shape[1]):
y[i, j] = x[i, j] * (np.abs(x[i, j]) > threshold[i, j])
return y


def prox_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - lam, 0)


def prox_weighted_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - lam, np.zeros(x.shape))


def prox_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * lam)
def validate_prox_and_reg_inputs(func):
@wraps(func)
def wrapper(x, regularization_weight):
# Example validation: check if both a and b are positive integers
if isinstance(regularization_weight, np.ndarray) and (
regularization_weight.shape != x.shape
and regularization_weight.shape != (1, 1)
):
raise ValueError(
f"Invalid shape for 'regularization_weight': {regularization_weight.shape}. Must be the same shape as x: {x.shape}."
)

# If validation passes, call the original function
return func(x, regularization_weight)

def prox_weighted_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * lam)
return wrapper


# TODO: replace code block with proper math block
def prox_cad(x: NDArray[np.float64], lam: NDArray[np.float64]):
def get_prox(
regularization: str,
) -> Callable[
[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], NDArray[np.float64]
]:
"""
Proximal operator for CAD regularization
.. code ::
Args:
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | 'l2' | 'weighted_l2'
prox_cad(z, a, b) =
0 if |z| < a
sign(z)(|z| - a) if a < |z| <= b
z if |z| > b
Returns:
proximal_operator: (x: np.array, regularization_weight: float | np,array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
and regularization weight factor which can be a scalar or an array of shape (n_targets, n_features),
and returns an array of shape (n_targets, n_features)
"""

Entries of :math:`x` smaller than a in magnitude are set to 0,
entries with magnitudes larger than b are untouched,
and entries in between have soft-thresholding applied.
def prox_l0(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for weighted l0 regularization."""
y = np.zeros(np.shape(x))
threshold = np.sqrt(2 * regularization_weight)
m, n = regularization_weight.shape
for i in range(m):
for j in range(n):
y[i, j] = x[i, j] * (np.abs(x[i, j]) > threshold[i, j])
return y

def prox_l1(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(
np.abs(x) - regularization_weight, np.zeros(x.shape)
)

For simplicity we set :math:`b = 5*a` in this implementation.
"""
lower_threshold = lam
upper_threshold = 5 * lam
return prox_l0(x, upper_threshold) + prox_l1(x, lower_threshold) * (
np.abs(x) < upper_threshold
)
def prox_l2(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

def prox_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

def get_prox(
regularization: str,
) -> Callable[[NDArray[np.float64], NDArray[np.float64]], NDArray[np.float64]]:
prox = {
"l0": prox_l0,
"weighted_l0": prox_weighted_l0,
"l1": prox_l1,
"weighted_l1": prox_weighted_l1,
"l2": prox_l2,
"weighted_l2": prox_weighted_l2,
"cad": prox_cad,
}
if regularization.lower() in prox:
return prox[regularization.lower()]
return validate_prox_and_reg_inputs(prox[regularization.lower()])
else:
raise NotImplementedError("{} has not been implemented".format(regularization))


def regularization_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
return lam * np.count_nonzero(x)


def regualization_weighted_l0(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam[np.nonzero(x)])


def regularization_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * np.abs(x))


def regualization_weighted_l1(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * np.abs(x))
def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], float]:
"""
Args:
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | 'l2' | 'weighted_l2'
Returns:
regularization_function: (x: np.array, regularization_weight: float) -> np.array
A function that takes an input x of shape (n_targets, n_features)
and regularization weight factor which can be a scalar or an array of shape (n_targets, n_features),
and returns a float
"""

def regularization_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * x**2)
def regularization_l0(x: NDArray[np.float64], regularization_weight: np.float64):
return regularization_weight * np.count_nonzero(x)

def regualization_weighted_l0(
x: NDArray[np.float64], regularization_weight: np.float64
):
return np.sum(regularization_weight[np.nonzero(x)])

def regualization_weighted_l2(x: NDArray[np.float64], lam: NDArray[np.float64]):
return np.sum(lam * x**2)
def regularization_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * np.abs(x))

def regualization_weighted_l1(
x: NDArray[np.float64], regularization_weight: np.float64
):
return np.sum(regularization_weight * np.abs(x))

def regularization_cad(x: NDArray[np.float64], lam: NDArray[np.float64]):
# dummy function
return 0
def regularization_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * x**2)

def regualization_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * x**2)

def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], NDArray[np.float64]], float]:
regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regualization_weighted_l0,
"l1": regularization_l1,
"weighted_l1": regualization_weighted_l1,
"l2": regularization_l2,
"weighted_l2": regualization_weighted_l2,
"cad": regularization_cad,
}
if regularization.lower() in regularization_fn:
return regularization_fn[regularization.lower()]
return validate_prox_and_reg_inputs(regularization_fn[regularization.lower()])
else:
raise NotImplementedError("{} has not been implemented".format(regularization))

Expand Down
8 changes: 0 additions & 8 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,6 @@ def test_prox_functions(data_derivative_1d, optimizer, thresholder):
check_is_fitted(model)


def test_cad_prox_function(data_derivative_1d):
x, x_dot = data_derivative_1d
x = x.reshape(-1, 1)
model = SR3(thresholder="cad")
model.fit(x, x_dot)
check_is_fitted(model)


@pytest.mark.parametrize("thresholder", ["weighted_l0", "weighted_l1"])
def test_weighted_prox_functions(data, thresholder):
x, x_dot = data
Expand Down

0 comments on commit 9ff815a

Please sign in to comment.