Skip to content

Commit

Permalink
tst: Test local trapping
Browse files Browse the repository at this point in the history
Not necessarily the best test, but coverage should at least find any shape
errors that arise.  Also, remove tests for different regularizers from trapping,
now that that regularization is fully abstracted to superclass, with exception
of reg == 0 vs reg != 0
  • Loading branch information
Jacob-Stevens-Haas committed Sep 18, 2024
1 parent fa65917 commit 7893d43
Showing 1 changed file with 16 additions and 38 deletions.
54 changes: 16 additions & 38 deletions test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,60 +508,38 @@ def test_stable_linear_sr3_linear_library():
assert np.allclose(opt.coef_.flatten(), 0.0)


@pytest.mark.parametrize(
"params",
[
dict(regularizer="l1", reg_weight_lam=0, _include_bias=True),
dict(regularizer="l1", reg_weight_lam=1e-5, _include_bias=True),
dict(
regularizer="weighted_l1",
reg_weight_lam=np.zeros((1, 2)),
eta=1e5,
alpha_m=1e4,
alpha_A=1e5,
_include_bias=False,
),
dict(
regularizer="weighted_l1",
reg_weight_lam=1e-5 * np.ones((1, 2)),
_include_bias=False,
),
dict(regularizer="l2", reg_weight_lam=0, _include_bias=True),
dict(regularizer="l2", reg_weight_lam=1e-5, _include_bias=True),
dict(
regularizer="weighted_l2",
reg_weight_lam=np.zeros((1, 2)),
_include_bias=False,
),
dict(
regularizer="weighted_l2",
reg_weight_lam=1e-5 * np.ones((1, 2)),
_include_bias=False,
),
],
)
def test_trapping_sr3_quadratic_library(params):
@pytest.mark.parametrize("bias", (True, False))
@pytest.mark.parametrize("method", ("global", "local"))
@pytest.mark.parametrize("reg_weight", (0.0, 1e-1))
def test_trapping_sr3_quadratic_library(bias, method, reg_weight):
t = np.arange(0, 1, 0.1)
x = np.exp(-t).reshape((-1, 1))
x_dot = -x
features = np.hstack([x, x**2])
if params.get("_include_bias"):
if bias:
features = np.hstack([np.ones_like(x), features])

opt = TrappingSR3(_n_tgts=1, **params)
params = {
"_n_tgts": 1,
"_include_bias": bias,
"method": method,
"reg_weight_lam": reg_weight,
}

opt = TrappingSR3(**params)
opt.fit(features, x_dot)
check_is_fitted(opt)

# Rerun with identity constraints
r = x.shape[1]
N = 2 + params.get("_include_bias", 0)
N = 2 + bias
params["constraint_rhs"] = np.zeros(r * N)
params["constraint_lhs"] = np.eye(r * N, r * N)

opt = TrappingSR3(_n_tgts=1, **params)
opt = TrappingSR3(**params)
opt.fit(features, x_dot)
check_is_fitted(opt)
# check is solve was infeasible first
# check if solve was infeasible first
if not np.allclose(opt.m_history_[-1], opt.m_history_[0]):
assert np.allclose((opt.coef_.flatten())[0], 0.0, atol=1e-5)

Expand Down

0 comments on commit 7893d43

Please sign in to comment.