Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored DatetimeInfoExtractor to improve readability/condense #251

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Changed
- Added test_BaseTwoColumnTransformer base class for columns that require a list of two columns for input
- Added BaseDropOriginalMixin to mixin transformers to handle validation and method of dropping original features, also added appropriate test classes.
- Refactored MeanImputer tests in new format `#250 <https://github.com/lvgig/tubular/pull/250>`_
- Refactored DatetimeInfoExtractor to condense and improve readability


Removed
Expand Down
26 changes: 8 additions & 18 deletions tests/dates/test_DateTimeInfoExtractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import re

import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -63,7 +61,7 @@ def test_error_when_invalid_include_option(self):
"""Test that an exception is raised when include contains incorrect values."""
with pytest.raises(
ValueError,
match=r'elements in include should be in \["timeofday", "timeofmonth", "timeofyear", "dayofweek"\]',
match=r"DatetimeInfoExtractor: elements in include should be in \['timeofday', 'timeofmonth', 'timeofyear', 'dayofweek'\]",
):
DatetimeInfoExtractor(
columns=["a"],
Expand Down Expand Up @@ -138,27 +136,19 @@ def test_error_when_datetime_mapping_key_not_in_include(
[
(
{"timeofday": {"mapped": range(23)}},
re.escape(
"timeofday mapping dictionary should contain mapping for all hours between 0-23. {23} are missing",
),
r"DatetimeInfoExtractor: timeofday mapping dictionary should contain mapping for all values between 0-23. \{23\} are missing",
),
(
{"timeofmonth": {"mapped": range(1, 31)}},
re.escape(
"timeofmonth mapping dictionary should contain mapping for all days between 1-31. {31} are missing",
),
r"DatetimeInfoExtractor: timeofmonth mapping dictionary should contain mapping for all values between 1-31. \{31\} are missing",
),
(
{"timeofyear": {"mapped": range(1, 12)}},
re.escape(
"timeofyear mapping dictionary should contain mapping for all months between 1-12. {12} are missing",
),
r"DatetimeInfoExtractor: timeofyear mapping dictionary should contain mapping for all values between 1-12. \{12\} are missing",
),
(
{"dayofweek": {"mapped": range(6)}},
re.escape(
"dayofweek mapping dictionary should contain mapping for all days between 0-6. {6} are missing",
),
r"DatetimeInfoExtractor: dayofweek mapping dictionary should contain mapping for all values between 0-6. \{6\} are missing",
),
],
)
Expand All @@ -177,8 +167,8 @@ class TestMapValues:
def test_incorrect_type_input(self, incorrect_type_input, timeofday_extractor):
"""Test that an error is raised if input is the wrong type."""
with pytest.raises(
TypeError,
match="DatetimeInfoExtractor: value should be float or int",
ValueError,
match="DatetimeInfoExtractor: value for timeofday mapping in self._map_values should be an integer value in 0-23",
):
timeofday_extractor._map_values(incorrect_type_input, "timeofday")

Expand All @@ -191,7 +181,7 @@ def test_out_of_bounds_or_fractional_input(
"""Test that an error is raised when value is outside of 0-23 range."""
with pytest.raises(
ValueError,
match="DatetimeInfoExtractor: value for timeofday mapping in self._map_values should be an integer value in 0-23",
match="DatetimeInfoExtractor: value for timeofday mapping in self._map_values should be an integer value in 0-23",
):
timeofday_extractor._map_values(incorrect_size_input, "timeofday")

Expand Down
251 changes: 103 additions & 148 deletions tubular/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,56 @@ class DatetimeInfoExtractor(BaseDateTransformer):

"""

TIME_OF_DAY = "timeofday"
TIME_OF_MONTH = "timeofmonth"
TIME_OF_YEAR = "timeofyear"
DAY_OF_WEEK = "dayofweek"

DEFAULT_MAPPINGS = {
TIME_OF_DAY: {
"night": range(6), # Midnight - 6am
"morning": range(6, 12), # 6am - Noon
"afternoon": range(12, 18), # Noon - 6pm
"evening": range(18, 24), # 6pm - Midnight
},
TIME_OF_MONTH: {
"start": range(1, 11),
"middle": range(11, 21),
"end": range(21, 32),
},
TIME_OF_YEAR: {
"spring": range(3, 6), # Mar, Apr, May
"summer": range(6, 9), # Jun, Jul, Aug
"autumn": range(9, 12), # Sep, Oct, Nov
"winter": [12, 1, 2], # Dec, Jan, Feb
},
DAY_OF_WEEK: {
"monday": [0],
"tuesday": [1],
"wednesday": [2],
"thursday": [3],
"friday": [4],
"saturday": [5],
"sunday": [6],
},
}

INCLUDE_OPTIONS = list(DEFAULT_MAPPINGS.keys())

RANGE_TO_MAP = {
TIME_OF_DAY: set(range(24)),
TIME_OF_MONTH: set(range(1, 32)),
TIME_OF_YEAR: set(range(1, 13)),
DAY_OF_WEEK: set(range(7)),
}

DATETIME_ATTR = {
TIME_OF_DAY: "hour",
TIME_OF_MONTH: "day",
TIME_OF_YEAR: "month",
DAY_OF_WEEK: "weekday",
}

def __init__(
self,
columns: str | list[str],
Expand All @@ -880,7 +930,7 @@ def __init__(
**kwargs: dict[str, bool],
) -> None:
if include is None:
include = ["timeofday", "timeofmonth", "timeofyear", "dayofweek"]
include = self.INCLUDE_OPTIONS
else:
if type(include) is not list:
msg = f"{self.classname()}: include should be List"
Expand All @@ -893,8 +943,6 @@ def __init__(
msg = f"{self.classname()}: datetime_mappings should be Dict"
raise TypeError(msg)

# note, this has highlighted that new_column_name func might be worth pulling out into a mixin
# but will leave for followup PR
super().__init__(
columns=columns,
drop_original=drop_original,
Expand All @@ -903,13 +951,8 @@ def __init__(
)

for var in include:
if var not in [
"timeofday",
"timeofmonth",
"timeofyear",
"dayofweek",
]:
msg = f'{self.classname()}: elements in include should be in ["timeofday", "timeofmonth", "timeofyear", "dayofweek"]'
if var not in self.INCLUDE_OPTIONS:
msg = f"{self.classname()}: elements in include should be in {self.INCLUDE_OPTIONS}"
raise ValueError(msg)

if datetime_mappings != {}:
Expand All @@ -925,103 +968,50 @@ def __init__(
self.datetime_mappings = datetime_mappings
self.mappings_provided = self.datetime_mappings.keys()

# Select correct mapping either from default or user input

if ("timeofday" in include) and ("timeofday" in self.mappings_provided):
timeofday_mapping = self.datetime_mappings["timeofday"]
elif "timeofday" in include: # Choose default mapping
timeofday_mapping = {
"night": range(6), # Midnight - 6am
"morning": range(6, 12), # 6am - Noon
"afternoon": range(12, 18), # Noon - 6pm
"evening": range(18, 24), # 6pm - Midnight
}

if ("timeofmonth" in include) and ("timeofmonth" in self.mappings_provided):
timeofmonth_mapping = self.datetime_mappings["timeofmonth"]
elif "timeofmonth" in include: # Choose default mapping
timeofmonth_mapping = {
"start": range(11),
"middle": range(11, 21),
"end": range(21, 32),
}

if ("timeofyear" in include) and ("timeofyear" in self.mappings_provided):
timeofyear_mapping = self.datetime_mappings["timeofyear"]
elif "timeofyear" in include: # Choose default mapping
timeofyear_mapping = {
"spring": range(3, 6), # Mar, Apr, May
"summer": range(6, 9), # Jun, Jul, Aug
"autumn": range(9, 12), # Sep, Oct, Nov
"winter": [12, 1, 2], # Dec, Jan, Feb
}

if ("dayofweek" in include) and ("dayofweek" in self.mappings_provided):
dayofweek_mapping = self.datetime_mappings["dayofweek"]
elif "dayofweek" in include: # Choose default mapping
dayofweek_mapping = {
"monday": [0],
"tuesday": [1],
"wednesday": [2],
"thursday": [3],
"friday": [4],
"saturday": [5],
"sunday": [6],
}

# Invert dictionaries for quicker lookup

if "timeofday" in include:
self.timeofday_mapping = {
vi: k for k, v in timeofday_mapping.items() for vi in v
}
if set(self.timeofday_mapping.keys()) != set(range(24)):
msg = f"{self.classname()}: timeofday mapping dictionary should contain mapping for all hours between 0-23. {set(range(24)) - set(self.timeofday_mapping.keys())} are missing"
raise ValueError(msg)

# Check if all hours in dictionary
else:
self.timeofday_mapping = {}

if "timeofmonth" in include:
self.timeofmonth_mapping = {
vi: k for k, v in timeofmonth_mapping.items() for vi in v
}
if set(self.timeofmonth_mapping.keys()) != set(range(32)):
msg = f"{self.classname()}: timeofmonth mapping dictionary should contain mapping for all days between 1-31. {set(range(1, 32)) - set(self.timeofmonth_mapping.keys())} are missing"
raise ValueError(msg)
else:
self.timeofmonth_mapping = {}

if "timeofyear" in include:
self.timeofyear_mapping = {
vi: k for k, v in timeofyear_mapping.items() for vi in v
}
if set(self.timeofyear_mapping.keys()) != set(range(1, 13)):
msg = f"{self.classname()}: timeofyear mapping dictionary should contain mapping for all months between 1-12. {set(range(1, 13)) - set(self.timeofyear_mapping.keys())} are missing"
raise ValueError(msg)
self._process_provided_mappings()

else:
self.timeofyear_mapping = {}
def _process_provided_mappings(self) -> None:
"""Method to process user provided mappings. Sets mappings attribute, then transforms to set a second
inverted_datetime_mappings attribute. Validates against RANGE_TO_MAP.

if "dayofweek" in include:
self.dayofweek_mapping = {
vi: k for k, v in dayofweek_mapping.items() for vi in v
}
if set(self.dayofweek_mapping.keys()) != set(range(7)):
msg = f"{self.classname()}: dayofweek mapping dictionary should contain mapping for all days between 0-6. {set(range(7)) - set(self.dayofweek_mapping.keys())} are missing"
raise ValueError(msg)
Returns
-------
None
"""

else:
self.dayofweek_mapping = {}
self.mappings = {}
self.inverted_datetime_mappings = {}
for include_option in self.INCLUDE_OPTIONS:
if (include_option in self.include) and (
include_option in self.mappings_provided
):
self.mappings[include_option] = self.datetime_mappings[include_option]
else:
self.mappings[include_option] = self.DEFAULT_MAPPINGS[include_option]

# Invert dictionaries for quicker lookup
if include_option in self.include:
self.inverted_datetime_mappings[include_option] = {
vi: k for k, v in self.mappings[include_option].items() for vi in v
}

# check provided mappings fit required format
if (
set(self.inverted_datetime_mappings[include_option].keys())
!= self.RANGE_TO_MAP[include_option]
):
msg = f"{self.classname()}: {include_option} mapping dictionary should contain mapping for all values between {min(self.RANGE_TO_MAP[include_option])}-{max(self.RANGE_TO_MAP[include_option])}. {self.RANGE_TO_MAP[include_option] - set(self.inverted_datetime_mappings[include_option].keys())} are missing"
raise ValueError(msg)
else:
self.inverted_datetime_mappings[include_option] = {}

def _map_values(self, value: float, interval: str) -> str:
def _map_values(self, value: float, include_option: str) -> str:
"""Method to apply mappings for a specified interval ("timeofday", "timeofmonth", "timeofyear" or "dayofweek")
from corresponding mapping attribute to a single value.

Parameters
----------
interval : str
include_option : str
the time period to map "timeofday", "timeofmonth", "timeofyear" or "dayofweek"

value : float or int
Expand All @@ -1033,37 +1023,17 @@ def _map_values(self, value: float, interval: str) -> str:
str : str
Mapped value
"""
if type(value) is not float and type(value) is not int:
msg = f"{self.classname()}: value should be float or int"
raise TypeError(msg)

errors = {
"timeofday": "0-23",
"dayofweek": "0-6",
"timeofmonth": "1-31",
"timeofyear": "1-12",
}
ranges = {
"timeofday": (0, 24, 1),
"dayofweek": (0, 7, 1),
"timeofmonth": (1, 32, 1),
"timeofyear": (1, 13, 1),
}
mappings = {
"timeofday": self.timeofday_mapping,
"dayofweek": self.dayofweek_mapping,
"timeofmonth": self.timeofmonth_mapping,
"timeofyear": self.timeofyear_mapping,
}

if (not np.isnan(value)) and (value not in np.arange(*ranges[interval])):
msg = f"{self.classname()}: value for {interval} mapping in self._map_values should be an integer value in {errors[interval]}"
raise ValueError(msg)
if isinstance(value, float):
if np.isnan(value):
return np.nan
if value.is_integer():
value = int(value)

if np.isnan(value):
return np.nan
if isinstance(value, int) and value in self.RANGE_TO_MAP[include_option]:
return self.inverted_datetime_mappings[include_option][value]

return mappings[interval][value]
msg = f"{self.classname()}: value for {include_option} mapping in self._map_values should be an integer value in {min(self.RANGE_TO_MAP[include_option])}-{max(self.RANGE_TO_MAP[include_option])}"
davidhopkinson26 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(msg)

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
"""Transform - Extracts new features from datetime variables.
Expand All @@ -1081,28 +1051,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X = super().transform(X, datetime_only=True)

for col in self.columns:
if "timeofday" in self.include:
X[col + "_timeofday"] = X[col].dt.hour.apply(
self._map_values,
interval="timeofday",
)

if "timeofmonth" in self.include:
X[col + "_timeofmonth"] = X[col].dt.day.apply(
self._map_values,
interval="timeofmonth",
)

if "timeofyear" in self.include:
X[col + "_timeofyear"] = X[col].dt.month.apply(
self._map_values,
interval="timeofyear",
)

if "dayofweek" in self.include:
X[col + "_dayofweek"] = X[col].dt.weekday.apply(
for include_option in self.include:
X[col + "_" + include_option] = getattr(
X[col].dt,
self.DATETIME_ATTR[include_option],
).apply(
self._map_values,
interval="dayofweek",
include_option=include_option,
)

if self.drop_original:
Expand Down
Loading