Skip to content

fix Gale-Shapley bug #410

Merged
merged 9 commits into from
Dec 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Fixed
- Fix GaleShapleyFeatureSelectionTransform: accept `relevance_params` and pass them to the relevance table ([#410](https://github.com/tinkoff-ai/etna/pull/410))

## [1.5.0] - 2021-12-24
### Added
- Holiday Transform ([#359](https://github.com/tinkoff-ai/etna/pull/359))
Expand Down
7 changes: 5 additions & 2 deletions etna/transforms/gale_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __call__(self) -> Dict[str, str]:
class GaleShapleyFeatureSelectionTransform(BaseFeatureSelectionTransform):
"""GaleShapleyFeatureSelectionTransform provides feature filtering with Gale-Shapley matching algo according to relevance table."""

def __init__(self, relevance_table: RelevanceTable, top_k: int, use_rank: bool = False):
def __init__(self, relevance_table: RelevanceTable, top_k: int, use_rank: bool = False, **relevance_params):
"""Init GaleShapleyFeatureSelectionTransform.

Parameters
Expand All @@ -238,12 +238,15 @@ class to build relevance table
self.top_k = top_k
self.use_rank = use_rank
self.greater_is_better = False if use_rank else relevance_table.greater_is_better
self.relevance_params = relevance_params

def _compute_relevance_table(self, df: pd.DataFrame, regressors: List[str]) -> pd.DataFrame:
"""Compute relevance table with given data."""
targets_df = df.loc[:, pd.IndexSlice[:, "target"]]
regressors_df = df.loc[:, pd.IndexSlice[:, regressors]]
table = self.relevance_table(df=targets_df, df_exog=regressors_df, return_ranks=self.use_rank)
table = self.relevance_table(
df=targets_df, df_exog=regressors_df, return_ranks=self.use_rank, **self.relevance_params
)
return table

@staticmethod
Expand Down
12 changes: 12 additions & 0 deletions tests/test_transforms/test_gale_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.ensemble import RandomForestRegressor

from etna.analysis.feature_relevance import ModelRelevanceTable
from etna.analysis.feature_relevance import StatisticsRelevanceTable
from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
Expand Down Expand Up @@ -585,6 +587,16 @@ def test_gale_shapley_transform_fit_transform(ts_with_large_regressors_number: T
}


@pytest.mark.parametrize("use_rank", (True, False))
@pytest.mark.parametrize("top_k", (2, 3, 5, 6, 7))
def test_gale_shapley_transform_fit_model_based(
    ts_with_large_regressors_number: TSDataset,
    top_k: int,
    use_rank: bool,
):
    """Smoke-check that fit runs with a ModelRelevanceTable and an extra ``model`` keyword argument."""
    # The test makes no assertions; it passes as long as fit() does not raise.
    features = ts_with_large_regressors_number.df
    selector_kwargs = dict(
        relevance_table=ModelRelevanceTable(),
        top_k=top_k,
        use_rank=use_rank,
        model=RandomForestRegressor(),
    )
    selector = GaleShapleyFeatureSelectionTransform(**selector_kwargs)
    selector.fit(df=features)


@pytest.mark.xfail
def test_fit_transform_with_nans(regressor_exog_weekend):
transform = GaleShapleyFeatureSelectionTransform(
Expand Down