Skip to content

Commit

Permalink
Add plot_features_relevance (#579)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Mar 10, 2022
1 parent 8b3c063 commit a0a84fd
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Create `BasePipeline`, add prediction intervals to all the pipelines, move parameter n_fold to forecast ([#578](https://github.com/tinkoff-ai/etna/pull/578))
- Add stl_plot ([#575](https://github.com/tinkoff-ai/etna/pull/575))
- Add plot_features_relevance ([#579](https://github.com/tinkoff-ai/etna/pull/579))
- Add community section to README.md ([#580](https://github.com/tinkoff-ai/etna/pull/580))
- Create `AbstaractPipeline` ([#573](https://github.com/tinkoff-ai/etna/pull/573))
-
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from etna.analysis.plotters import plot_backtest_interactive
from etna.analysis.plotters import plot_clusters
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
from etna.analysis.plotters import plot_forecast
from etna.analysis.plotters import plot_residuals
from etna.analysis.plotters import plot_time_series_with_change_points
Expand Down
2 changes: 2 additions & 0 deletions etna/analysis/feature_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from etna.analysis.feature_selection.mrmr import AGGREGATION_FN
from etna.analysis.feature_selection.mrmr import AggregationMode
from etna.analysis.feature_selection.mrmr import mrmr
6 changes: 3 additions & 3 deletions etna/analysis/feature_selection/mrmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AggregationMode(str, Enum):
median = "median"


aggregation_fn = {
AGGREGATION_FN = {
AggregationMode.mean: np.mean,
AggregationMode.max: np.max,
AggregationMode.min: np.min,
Expand Down Expand Up @@ -58,8 +58,8 @@ def mrmr(
selected_features: List[str]
list of `top_k` selected regressors, sorted by their importance
"""
relevance_aggregation_fn = aggregation_fn[AggregationMode(relevance_aggregation_mode)]
redundancy_aggregation_fn = aggregation_fn[AggregationMode(redundancy_aggregation_mode)]
relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]
redundancy_aggregation_fn = AGGREGATION_FN[AggregationMode(redundancy_aggregation_mode)]

relevance = relevance_table.apply(relevance_aggregation_fn).fillna(0)

Expand Down
84 changes: 84 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
Expand All @@ -19,6 +20,9 @@
import seaborn as sns
from typing_extensions import Literal

from etna.analysis import RelevanceTable
from etna.analysis.feature_selection import AGGREGATION_FN
from etna.analysis.feature_selection import AggregationMode
from etna.transforms import Transform

if TYPE_CHECKING:
Expand Down Expand Up @@ -799,3 +803,83 @@ def plot_trend(
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()


def plot_feature_relevance(
ts: "TSDataset",
relevance_table: RelevanceTable,
normalized: bool = False,
relevance_aggregation_mode: Union[str, Literal["per-segment"]] = AggregationMode.mean,
relevance_params: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
):
"""
Plot relevance of the features.
The most important features are at the top, the least important are at the bottom.
Parameters
----------
ts:
TSDataset with timeseries data
relevance_table:
method to evaluate the feature relevance
normalized:
whether obtained relevances should be normalized to sum up to 1
relevance_aggregation_mode:
aggregation strategy for obtained feature relevance table;
all the strategies can be examined at `etna.analysis.feature_selection.AggregationMode`
relevance_params:
additional keyword arguments for `__call__` method of `RelevanceTable` instances
top_k:
number of best features to plot, if None plot all the features
segments:
segments to use
columns_num:
if `relevance_aggregation_mode="per-segment"` number of columns in subplots, otherwise the value is ignored
figsize:
size of the figure per subplot with one segment in inches
"""
if relevance_params is None:
relevance_params = {}
if not segments:
segments = sorted(ts.segments)

is_ascending = not relevance_table.greater_is_better
features = list(set(ts.columns.get_level_values("feature")) - {"target"})
relevance_df = relevance_table(df=ts[:, :, "target"], df_exog=ts[:, :, features], **relevance_params).loc[segments]

if relevance_aggregation_mode == "per-segment":
ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)
for i, segment in enumerate(segments):
relevance = relevance_df.loc[segment].sort_values(ascending=is_ascending)
# warning about NaNs
if relevance.isna().any():
na_relevance_features = relevance[relevance.isna()].index.tolist()
warnings.warn(
f"Relevances on segment: {segment} of features: {na_relevance_features} can't be calculated."
)
relevance = relevance.dropna()[:top_k]
if normalized:
relevance = relevance / relevance.sum()
sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax[i])
ax[i].set_title(f"Feature relevance: {segment}")

else:
relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]
relevance = relevance_df.apply(lambda x: relevance_aggregation_fn(x[~x.isna()])) # type: ignore
relevance = relevance.sort_values(ascending=is_ascending)
# warning about NaNs
if relevance.isna().any():
na_relevance_features = relevance[relevance.isna()].index.tolist()
warnings.warn(f"Relevances of features: {na_relevance_features} can't be calculated.")
# if top_k == None, all the values are selected
relevance = relevance.dropna()[:top_k]
if normalized:
relevance = relevance / relevance.sum()
_, ax = plt.subplots(figsize=figsize, constrained_layout=True)
sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax)
ax.set_title("Feature relevance") # type: ignore

0 comments on commit a0a84fd

Please sign in to comment.