Skip to content

Commit

Permalink
Add update sklearn, python and pytorch-forecasting versions (#445)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Gabdushev <[email protected]>
  • Loading branch information
alex-hse-repository and martins0n authored Jan 13, 2022
1 parent cf95197 commit 5b0f1f1
Show file tree
Hide file tree
Showing 7 changed files with 464 additions and 431 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change the way `LagTransform`, `DateFlagsTransform` and `TimeFlagsTransform` generate column names ([#421](https://github.com/tinkoff-ai/etna/pull/421))
- Clarify the behaviour of TimeSeriesImputerTransform in case of all NaN values ([#427](https://github.com/tinkoff-ai/etna/pull/427))
- Fixed bug in title in `sample_acf_plot` method ([#432](https://github.com/tinkoff-ai/etna/pull/432))
- Pytorch-forecasting and sklearn version update + some pytorch transform API changes ([#445](https://github.com/tinkoff-ai/etna/pull/445))

### Fixed
- Add relevance_params in GaleShapleyFeatureSelectionTransform ([#410](https://github.com/tinkoff-ai/etna/pull/410))
Expand Down
8 changes: 4 additions & 4 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def make_future(self, future_steps: int) -> "TSDataset":
segment segment_0 segment_1
feature regressor_1 regressor_2 target regressor_1 regressor_2 target
timestamp
2021-07-01 30 35 nan 70 75 nan
2021-07-02 31 36 nan 71 76 nan
2021-07-03 32 37 nan 72 77 nan
2021-07-04 33 38 nan 73 78 nan
2021-07-01 30 35 NaN 70 75 NaN
2021-07-02 31 36 NaN 71 76 NaN
2021-07-03 32 37 NaN 72 77 NaN
2021-07-04 33 38 NaN 73 78 NaN
"""
self._check_endings()
max_date_in_dataset = self.df.index.max()
Expand Down
6 changes: 5 additions & 1 deletion etna/ensembles/stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def _make_features(
features = pd.DataFrame()
if self.filtered_features_for_final_model is not None:
features_in_forecasts = [
set(forecast.columns.get_level_values("feature")).intersection(self.filtered_features_for_final_model)
list(
set(forecast.columns.get_level_values("feature")).intersection(
self.filtered_features_for_final_model
)
)
for forecast in forecasts
]
features = pd.concat(
Expand Down
12 changes: 5 additions & 7 deletions etna/transforms/nn/pytorch_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def __init__(
time_varying_unknown_categoricals: Optional[List[str]] = None,
time_varying_unknown_reals: Optional[List[str]] = None,
variable_groups: Optional[Dict[str, List[int]]] = None,
dropout_categoricals: Optional[List[str]] = None,
constant_fill_strategy: Optional[Dict[str, Union[str, float, int, bool]]] = None,
allow_missings: bool = True,
allow_missing_timesteps: bool = True,
lags: Optional[Dict[str, List[int]]] = None,
add_relative_time_idx: bool = True,
add_target_scales: bool = True,
Expand All @@ -63,7 +62,7 @@ def __init__(
Reference
---------
https://github.com/jdb78/pytorch-forecasting/blob/v0.8.5/pytorch_forecasting/data/timeseries.py#L117
https://github.com/jdb78/pytorch-forecasting/blob/v0.9.2/pytorch_forecasting/data/timeseries.py#L117
"""
super().__init__()
self.max_encoder_length = max_encoder_length
Expand All @@ -85,10 +84,9 @@ def __init__(
self.add_relative_time_idx = add_relative_time_idx
self.add_target_scales = add_target_scales
self.add_encoder_length = add_encoder_length
self.allow_missings = allow_missings
self.allow_missing_timesteps = allow_missing_timesteps
self.target_normalizer = target_normalizer
self.categorical_encoders = categorical_encoders if categorical_encoders else {}
self.dropout_categoricals = dropout_categoricals if dropout_categoricals else []
self.constant_fill_strategy = constant_fill_strategy if constant_fill_strategy else []
self.lags = lags if lags else {}
self.scalers = scalers if scalers else {}
Expand Down Expand Up @@ -144,14 +142,14 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform":
add_relative_time_idx=self.add_relative_time_idx,
add_target_scales=self.add_target_scales,
add_encoder_length=self.add_encoder_length,
allow_missings=self.allow_missings,
allow_missing_timesteps=self.allow_missing_timesteps,
target_normalizer=self.target_normalizer,
static_categoricals=self.static_categoricals,
min_prediction_idx=self.min_prediction_idx,
variable_groups=self.variable_groups,
dropout_categoricals=self.dropout_categoricals,
constant_fill_strategy=self.constant_fill_strategy,
lags=self.lags,
categorical_encoders=self.categorical_encoders,
scalers=self.scalers,
)

Expand Down
12 changes: 5 additions & 7 deletions examples/NN_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,8 @@
" time_varying_unknown_categoricals: List[str] = [],\n",
" time_varying_unknown_reals: List[str] = [],\n",
" variable_groups: Dict[str, List[int]] = {},\n",
" dropout_categoricals: List[str] = [],\n",
" constant_fill_strategy: Dict[str, Union[str, float, int, bool]] = {},\n",
" allow_missings: bool = True,\n",
" allow_missing_timesteps: bool = True,\n",
" lags: Dict[str, List[int]] = {},\n",
" add_relative_time_idx: bool = True,\n",
" add_target_scales: bool = True,\n",
Expand Down Expand Up @@ -919,8 +918,7 @@
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": false
}
},
"outputs": [
{
Expand Down Expand Up @@ -1567,9 +1565,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "local-venv",
"language": "python",
"name": "python3"
"name": "local-venv"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1581,7 +1579,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.8.6"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 5b0f1f1

Please sign in to comment.