Skip to content

Commit

Permalink
feat: add parameter lasso_ratio to ElasticNetRegression (#237)
Browse files Browse the repository at this point in the history
Closes #166.

### Summary of Changes

Added parameter `lasso_ratio` to `ElasticNetRegression` and tests for
edge cases 0, 1, invalid and default.

---------

Co-authored-by: zzril <>
Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
robmeth and megalinter-bot authored Apr 21, 2023
1 parent 3aad07d commit 4a1a736
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/safeds/ml/classical/regression/_elastic_net_regression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

from sklearn.linear_model import ElasticNet as sk_ElasticNet
Expand All @@ -15,7 +16,22 @@
class ElasticNetRegression(Regressor):
"""Elastic net regression."""

def __init__(self) -> None:
def __init__(self, lasso_ratio: float = 0.5) -> None:
if lasso_ratio < 0 or lasso_ratio > 1:
raise ValueError("lasso_ratio must be between 0 and 1.")
elif lasso_ratio == 0:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
stacklevel=1,
)
elif lasso_ratio == 1:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
stacklevel=1,
)
self.lasso_ratio = lasso_ratio
self._wrapped_regressor: sk_ElasticNet | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -41,10 +57,10 @@ def fit(self, training_set: TaggedTable) -> ElasticNetRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_ElasticNet()
wrapped_regressor = sk_ElasticNet(l1_ratio=self.lasso_ratio)
fit(wrapped_regressor, training_set)

result = ElasticNetRegression()
result = ElasticNetRegression(self.lasso_ratio)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression._elastic_net_regression import ElasticNetRegression


def test_lasso_ratio_valid() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])
lasso_ratio = 0.3

elastic_net_regression = ElasticNetRegression(lasso_ratio).fit(tagged_training_set)
assert elastic_net_regression._wrapped_regressor is not None
assert elastic_net_regression._wrapped_regressor.l1_ratio == lasso_ratio


def test_lasso_ratio_invalid() -> None:
with pytest.raises(ValueError, match="lasso_ratio must be between 0 and 1."):
ElasticNetRegression(-1)


def test_lasso_ratio_zero() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
):
ElasticNetRegression(0)


def test_lasso_ratio_one() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
):
ElasticNetRegression(1)


# (Default parameter is tested in `test_regressor.py`.)

0 comments on commit 4a1a736

Please sign in to comment.