Skip to content

Commit

Permalink
Fix PR 3
Browse files Browse the repository at this point in the history
  • Loading branch information
malodetz committed Aug 7, 2023
1 parent 4159e50 commit 3dfac4b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
-
- Add modes `binary` and `categories` to `HolidayTransform` ([#763](https://github.com/tinkoff-ai/etna/pull/763))
- Add modes `binary` and `category` to `HolidayTransform` ([#763](https://github.com/tinkoff-ai/etna/pull/763))
- Add sorting by timestamp before the fit in `CatBoostPerSegmentModel` and `CatBoostMultiSegmentModel` ([#1337](https://github.com/tinkoff-ai/etna/pull/1337))
- Speed up metrics computation by optimizing segment validation, forbid NaNs during metrics computation ([#1338](https://github.com/tinkoff-ai/etna/pull/1338))
- Unify errors, warnings and checks in models ([#1312](https://github.com/tinkoff-ai/etna/pull/1312))
Expand Down
12 changes: 8 additions & 4 deletions etna/transforms/timestamp/holiday.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def _missing_(cls, value):

class HolidayTransform(IrreversibleTransform, FutureMixin):
"""
HolidayTransform generates series that indicates holidays in given dataframe.
Can either show holiday presence or their names (NO_HOLIDAY indicates absence).
HolidayTransform generates series that indicates holidays in given dataset.
In ``binary`` mode shows the presence of holiday in that day. In ``category`` mode shows the name of the holiday
with value "NO_HOLIDAY" reserved for days without holidays.
"""

NO_HOLIDAY: str = "NO_HOLIDAY"
_no_holiday_name: str = "NO_HOLIDAY"

def __init__(self, iso_code: str = "RUS", mode: str = "binary", out_column: Optional[str] = None):
"""
Expand Down Expand Up @@ -90,7 +92,9 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:

out_column = self._get_column_name()
if self._mode is HolidayTransformMode.category:
encoded_matrix = np.array([self.holidays[x] if x in self.holidays else self.NO_HOLIDAY for x in df.index])
encoded_matrix = np.array(
[self.holidays[x] if x in self.holidays else self._no_holiday_name for x in df.index]
)
else:
encoded_matrix = np.array([int(x in self.holidays) for x in df.index])
encoded_matrix = encoded_matrix.reshape(-1, 1).repeat(len(cols), axis=1)
Expand Down

0 comments on commit 3dfac4b

Please sign in to comment.