Skip to content

Add plot_features_relevance #579

Merged
merged 8 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Create `BasePipeline`, add prediction intervals to all the pipelines, move parameter n_fold to forecast ([#578](https://github.com/tinkoff-ai/etna/pull/578))
- Add stl_plot ([#575](https://github.com/tinkoff-ai/etna/pull/575))
- Add plot_feature_relevance ([#579](https://github.com/tinkoff-ai/etna/pull/579))
- Add community section to README.md ([#580](https://github.com/tinkoff-ai/etna/pull/580))
- Create `AbstractPipeline` ([#573](https://github.com/tinkoff-ai/etna/pull/573))
-
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from etna.analysis.plotters import plot_backtest_interactive
from etna.analysis.plotters import plot_clusters
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
from etna.analysis.plotters import plot_forecast
from etna.analysis.plotters import plot_residuals
from etna.analysis.plotters import plot_time_series_with_change_points
Expand Down
2 changes: 2 additions & 0 deletions etna/analysis/feature_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from etna.analysis.feature_selection.mrmr import AGGREGATION_FN
from etna.analysis.feature_selection.mrmr import AggregationMode
from etna.analysis.feature_selection.mrmr import mrmr
6 changes: 3 additions & 3 deletions etna/analysis/feature_selection/mrmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AggregationMode(str, Enum):
median = "median"


aggregation_fn = {
AGGREGATION_FN = {
AggregationMode.mean: np.mean,
AggregationMode.max: np.max,
AggregationMode.min: np.min,
Expand Down Expand Up @@ -58,8 +58,8 @@ def mrmr(
selected_features: List[str]
list of `top_k` selected regressors, sorted by their importance
"""
relevance_aggregation_fn = aggregation_fn[AggregationMode(relevance_aggregation_mode)]
redundancy_aggregation_fn = aggregation_fn[AggregationMode(redundancy_aggregation_mode)]
relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]
redundancy_aggregation_fn = AGGREGATION_FN[AggregationMode(redundancy_aggregation_mode)]

relevance = relevance_table.apply(relevance_aggregation_fn).fillna(0)

Expand Down
84 changes: 84 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
Expand All @@ -19,6 +20,9 @@
import seaborn as sns
from typing_extensions import Literal

from etna.analysis import RelevanceTable
from etna.analysis.feature_selection import AGGREGATION_FN
from etna.analysis.feature_selection import AggregationMode
from etna.transforms import Transform

if TYPE_CHECKING:
Expand Down Expand Up @@ -799,3 +803,83 @@ def plot_trend(
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()


def _select_relevance_to_plot(
    relevance: "pd.Series",
    is_ascending: bool,
    normalized: bool,
    top_k: Optional[int],
    segment: Optional[str] = None,
) -> "pd.Series":
    """
    Prepare a series of feature relevances for plotting.

    The series is sorted, features with undefined (NaN) relevance are reported with a warning and dropped,
    only the first ``top_k`` features are kept and, optionally, the kept values are normalized.

    Parameters
    ----------
    relevance:
        relevance values indexed by feature name
    is_ascending:
        whether the values should be sorted in ascending order
    normalized:
        whether the kept relevances should be divided by their sum
    top_k:
        number of features to keep after sorting, if None keep all the features
    segment:
        name of the segment the relevances belong to, it is used only in the warning message;
        if None the relevances are treated as aggregated over segments

    Returns
    -------
    :
        sorted, cleaned and optionally normalized relevances
    """
    relevance = relevance.sort_values(ascending=is_ascending)

    # warning about NaNs
    na_mask = relevance.isna()
    if na_mask.any():
        na_relevance_features = relevance[na_mask].index.tolist()
        if segment is None:
            warnings.warn(f"Relevances of features: {na_relevance_features} can't be calculated.")
        else:
            warnings.warn(
                f"Relevances on segment: {segment} of features: {na_relevance_features} can't be calculated."
            )

    # `iloc` makes the slice explicitly positional; if top_k == None, all the values are selected
    relevance = relevance.dropna().iloc[:top_k]

    if normalized:
        total = relevance.sum()
        if total == 0:
            # dividing by zero would replace every bar with NaN/inf, so the raw values are kept instead
            warnings.warn("Relevances sum up to zero, normalization is skipped.")
        else:
            relevance = relevance / total
    return relevance


def plot_feature_relevance(
    ts: "TSDataset",
    relevance_table: RelevanceTable,
    normalized: bool = False,
    relevance_aggregation_mode: Union[str, Literal["per-segment"]] = AggregationMode.mean,
    relevance_params: Optional[Dict[str, Any]] = None,
    top_k: Optional[int] = None,
    segments: Optional[List[str]] = None,
    columns_num: int = 2,
    figsize: Tuple[int, int] = (10, 5),
):
    """
    Plot relevance of the features.

    The most important features are at the top, the least important are at the bottom.

    Parameters
    ----------
    ts:
        TSDataset with timeseries data
    relevance_table:
        method to evaluate the feature relevance
    normalized:
        whether obtained relevances should be normalized to sum up to 1;
        normalization is applied to the plotted (``top_k``) features and is skipped with a warning
        if their relevances sum up to zero
    relevance_aggregation_mode:
        aggregation strategy for obtained feature relevance table;
        all the strategies can be examined at `etna.analysis.feature_selection.AggregationMode`;
        if ``"per-segment"`` no aggregation is made and each segment is plotted separately
    relevance_params:
        additional keyword arguments for `__call__` method of `RelevanceTable` instances
    top_k:
        number of best features to plot, if None plot all the features
    segments:
        segments to use, if None or empty all the segments of ``ts`` are used
    columns_num:
        if `relevance_aggregation_mode="per-segment"` number of columns in subplots, otherwise the value is ignored
    figsize:
        size of the figure per subplot with one segment in inches

    Raises
    ------
    ValueError:
        if ``relevance_aggregation_mode`` is neither ``"per-segment"`` nor a valid ``AggregationMode`` value
    """
    if relevance_params is None:
        relevance_params = {}
    if not segments:
        segments = sorted(ts.segments)

    is_per_segment = relevance_aggregation_mode == "per-segment"
    if not is_per_segment:
        # resolve the aggregation function up front, so an unknown mode fails
        # before the relevance table is computed
        relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]

    is_ascending = not relevance_table.greater_is_better
    # sorted to make the order of the features independent of set iteration order
    features = sorted(set(ts.columns.get_level_values("feature")) - {"target"})
    relevance_df = relevance_table(df=ts[:, :, "target"], df_exog=ts[:, :, features], **relevance_params).loc[segments]

    if is_per_segment:
        ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)
        for i, segment in enumerate(segments):
            relevance = _select_relevance_to_plot(
                relevance=relevance_df.loc[segment],
                is_ascending=is_ascending,
                normalized=normalized,
                top_k=top_k,
                segment=segment,
            )
            sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax[i])
            ax[i].set_title(f"Feature relevance: {segment}")
    else:
        # NaNs are excluded before aggregation, so a feature is undefined only if it is NaN on every segment
        aggregated = relevance_df.apply(lambda column: relevance_aggregation_fn(column.dropna()))  # type: ignore
        relevance = _select_relevance_to_plot(
            relevance=aggregated,
            is_ascending=is_ascending,
            normalized=normalized,
            top_k=top_k,
        )
        _, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax)
        ax.set_title("Feature relevance")  # type: ignore