diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57c22e2ee..7b25ede12 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Added
+- RelevanceTable returns rank ([#268](https://github.com/tinkoff-ai/etna-ts/pull/268/))
+
+### Changed
+
+### Fixed
 
 ## [1.3.1] - 2021-11-12
 ### Changed
diff --git a/etna/analysis/feature_relevance/relevance.py b/etna/analysis/feature_relevance/relevance.py
index eceaa9421..61844a39a 100644
--- a/etna/analysis/feature_relevance/relevance.py
+++ b/etna/analysis/feature_relevance/relevance.py
@@ -2,6 +2,7 @@
 from abc import abstractmethod
 
 import pandas as pd
+import scipy.stats
 
 from etna.analysis.feature_relevance.relevance_table import get_model_relevance_table
 from etna.analysis.feature_relevance.relevance_table import get_statistics_relevance_table
@@ -21,8 +22,15 @@ def __init__(self, greater_is_better: bool):
         """
         self.greater_is_better = greater_is_better
 
+    def _get_ranks(self, table: pd.DataFrame) -> pd.DataFrame:
+        """Compute rank relevance table from relevance table; the input table is left unmodified."""
+        if self.greater_is_better:
+            table = table * -1  # rebind instead of ``*=`` so the caller's table is not mutated in place
+        rank_table = pd.DataFrame(scipy.stats.rankdata(table, axis=1), columns=table.columns, index=table.index)
+        return rank_table.astype(int)
+
     @abstractmethod
-    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.DataFrame:
+    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, return_ranks: bool = False, **kwargs) -> pd.DataFrame:
         """Compute relevance table.
         For each series in df compute relevance of corresponding series in df_exog.
 
@@ -32,6 +40,8 @@ def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.Data
             dataframe with series that will be used as target
         df_exog:
             dataframe with series to compute relevance for df
+        return_ranks:
+            if False return relevance values else return ranks of relevance values
 
         Returns
         -------
@@ -47,9 +57,11 @@ class StatisticsRelevanceTable(RelevanceTable):
     def __init__(self):
         super().__init__(greater_is_better=False)
 
-    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.DataFrame:
+    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, return_ranks: bool = False, **kwargs) -> pd.DataFrame:
         """Compute feature relevance table with etna.analysis.get_statistics_relevance_table method."""
         table = get_statistics_relevance_table(df=df, df_exog=df_exog)
+        if return_ranks:
+            return self._get_ranks(table)
         return table
 
 
@@ -59,7 +71,9 @@ class ModelRelevanceTable(RelevanceTable):
     def __init__(self):
         super().__init__(greater_is_better=True)
 
-    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, **kwargs) -> pd.DataFrame:
+    def __call__(self, df: pd.DataFrame, df_exog: pd.DataFrame, return_ranks: bool = False, **kwargs) -> pd.DataFrame:
         """Compute feature relevance table with etna.analysis.get_model_relevance_table method."""
         table = get_model_relevance_table(df=df, df_exog=df_exog, **kwargs)
+        if return_ranks:
+            return self._get_ranks(table)
         return table
diff --git a/tests/test_analysis/test_feature_relevance/test_relevance.py b/tests/test_analysis/test_feature_relevance/test_relevance.py
index 0fa11b9cc..100be660d 100644
--- a/tests/test_analysis/test_feature_relevance/test_relevance.py
+++ b/tests/test_analysis/test_feature_relevance/test_relevance.py
@@ -1,3 +1,4 @@
+import pytest
 from sklearn.tree import DecisionTreeRegressor
 
 from etna.analysis.feature_relevance import ModelRelevanceTable
@@ -8,11 +9,26 @@ def test_statistics_relevance_table(simple_df_relevance):
     rt = StatisticsRelevanceTable()
     assert not rt.greater_is_better
     df, df_exog = simple_df_relevance
-    assert rt(df=df, df_exog=df_exog).shape == (2, 2)
+    assert rt(df=df, df_exog=df_exog, return_ranks=False).shape == (2, 2)
 
 
 def test_model_relevance_table(simple_df_relevance):
     rt = ModelRelevanceTable()
     assert rt.greater_is_better
     df, df_exog = simple_df_relevance
-    assert rt(df=df, df_exog=df_exog, model=DecisionTreeRegressor()).shape == (2, 2)
+    assert rt(df=df, df_exog=df_exog, return_ranks=False, model=DecisionTreeRegressor()).shape == (2, 2)
+
+
+@pytest.mark.parametrize(
+    "greater_is_better,answer",
+    ((True, [1, 2, 2, 1]), (False, [2, 1, 1, 2])),
+)
+def test_relevance_table_ranks(greater_is_better, answer, simple_df_relevance):
+    rt = ModelRelevanceTable()
+    rt.greater_is_better = greater_is_better
+    df, df_exog = simple_df_relevance
+    table = rt(df=df, df_exog=df_exog, return_ranks=True, model=DecisionTreeRegressor())
+    assert table["regressor_1"]["1"] == answer[0]
+    assert table["regressor_2"]["1"] == answer[1]
+    assert table["regressor_1"]["2"] == answer[2]
+    assert table["regressor_2"]["2"] == answer[3]