Skip to content

Allow to pass not all values to TSDataset.train_test_split method #191

Merged
merged 3 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Instruction notebook for custom model and transform creation ([#180](https://github.com/tinkoff-ai/etna-ts/pull/180))
- Add inverse_transform in *OutliersTransform ([#160](https://github.com/tinkoff-ai/etna-ts/pull/160))
- Examples for CatBoostModelMultiSegment and CatBoostModelPerSegment ([#181](https://github.com/tinkoff-ai/etna-ts/pull/181))
- Simplify TSDataset.train_test_split method by allowing to pass not all values ([#191](https://github.com/tinkoff-ai/etna-ts/pull/191))

### Changed
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
Expand Down
70 changes: 57 additions & 13 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,21 +466,60 @@ def to_dataset(df: pd.DataFrame) -> pd.DataFrame:
df.columns.names = ["segment", "feature"]
return df

def _find_all_borders(
self,
train_start: Optional[TTimestamp],
train_end: Optional[TTimestamp],
test_start: Optional[TTimestamp],
test_end: Optional[TTimestamp],
) -> Tuple[TTimestamp, TTimestamp, TTimestamp, TTimestamp]:
"""Find borders for train_test_split if some values wasn't specified."""
if train_start is None:
julia-shenshina marked this conversation as resolved.
Show resolved Hide resolved
train_start_defined = self.df.index.min()
else:
train_start_defined = train_start

if test_end is None:
test_end_defined = self.df.index.max()
else:
test_end_defined = test_end

if train_end is None and test_start is None:
raise ValueError("One of train_end or test_start should be defined")

if train_end is None:
test_start_idx = self.df.index.get_loc(test_start)
train_end_defined = self.df.index[test_start_idx - 1]
else:
train_end_defined = train_end

if test_start is None:
train_end_idx = self.df.index.get_loc(train_end)
test_start_defined = self.df.index[train_end_idx + 1]
else:
test_start_defined = test_start

return train_start_defined, train_end_defined, test_start_defined, test_end_defined

def train_test_split(
self, train_start: Optional[TTimestamp], train_end: TTimestamp, test_start: TTimestamp, test_end: TTimestamp
self,
train_start: Optional[TTimestamp],
train_end: Optional[TTimestamp],
test_start: Optional[TTimestamp],
test_end: Optional[TTimestamp],
) -> Tuple["TSDataset", "TSDataset"]:
"""Split given df with train-test timestamp indices.

Parameters
----------
train_start:
start timestamp of new train dataset
start timestamp of new train dataset, if None first timestamp is used
train_end:
end timestamp of new train dataset
end timestamp of new train dataset, if None previous to test_start timestamp is used
test_start:
start timestamp of new test dataset
start timestamp of new test dataset, if None next to train_end timestamp is used
test_end:
end timestamp of new test dataset
end timestamp of new test dataset, if None last timestamp is used

Returns
-------
Expand Down Expand Up @@ -517,17 +556,22 @@ def train_test_split(
2021-02-05 -5.10 0.40 2.15
2021-02-06 -6.22 0.92 0.97
"""
if pd.Timestamp(test_end) > self.df.index.max():
raise UserWarning(f"Max timestamp in df is {self.df.index.max()}.")
if pd.Timestamp(train_start) < self.df.index.min():
raise UserWarning(f"Min timestamp in df is {self.df.index.min()}.")
train_df = self.df[train_start:train_end][self.raw_df.columns] # type: ignore
train_raw_df = self.raw_df[train_start:train_end] # type: ignore
train_start_defined, train_end_defined, test_start_defined, test_end_defined = self._find_all_borders(
train_start, train_end, test_start, test_end
)

if pd.Timestamp(test_end_defined) > self.df.index.max():
warnings.warn(f"Max timestamp in df is {self.df.index.max()}.")
if pd.Timestamp(train_start_defined) < self.df.index.min():
warnings.warn(f"Min timestamp in df is {self.df.index.min()}.")

train_df = self.df[train_start_defined:train_end_defined][self.raw_df.columns] # type: ignore
train_raw_df = self.raw_df[train_start_defined:train_end_defined] # type: ignore
train = TSDataset(df=train_df, df_exog=self.df_exog, freq=self.freq)
train.raw_df = train_raw_df

test_df = self.df[test_start:test_end][self.raw_df.columns] # type: ignore
test_raw_df = self.raw_df[train_start:test_end] # type: ignore
test_df = self.df[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
test_raw_df = self.raw_df[train_start_defined:test_end_defined] # type: ignore
test = TSDataset(df=test_df, df_exog=self.df_exog, freq=self.freq)
test.raw_df = test_raw_df

Expand Down
56 changes: 51 additions & 5 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,67 @@ def test_categorical_after_call_to_pandas():


@pytest.mark.parametrize(
"train_start,train_end,test_start,test_end",
(("2021-02-03", "2021-06-20", "2021-06-21", "2021-07-01"), (None, "2021-06-20", "2021-06-21", "2021-07-01")),
"borders, true_borders",
(
(
("2021-02-01", "2021-06-20", "2021-06-21", "2021-07-01"),
("2021-02-01", "2021-06-20", "2021-06-21", "2021-07-01"),
),
(
("2021-02-03", "2021-06-20", "2021-06-22", "2021-07-01"),
("2021-02-03", "2021-06-20", "2021-06-22", "2021-07-01"),
),
(
("2021-02-01", "2021-06-20", "2021-06-21", "2021-06-28"),
("2021-02-01", "2021-06-20", "2021-06-21", "2021-06-28"),
),
(
("2021-02-01", "2021-06-20", "2021-06-23", "2021-07-01"),
("2021-02-01", "2021-06-20", "2021-06-23", "2021-07-01"),
),
((None, "2021-06-20", "2021-06-23", "2021-06-28"), ("2021-02-01", "2021-06-20", "2021-06-23", "2021-06-28")),
(("2021-02-03", "2021-06-20", "2021-06-23", None), ("2021-02-03", "2021-06-20", "2021-06-23", "2021-07-01")),
((None, "2021-06-20", "2021-06-23", None), ("2021-02-01", "2021-06-20", "2021-06-23", "2021-07-01")),
((None, "2021-06-20", None, None), ("2021-02-01", "2021-06-20", "2021-06-21", "2021-07-01")),
((None, None, "2021-06-21", None), ("2021-02-01", "2021-06-20", "2021-06-21", "2021-07-01")),
),
)
def test_train_test_split(train_start, train_end, test_start, test_end, tsdf_with_exog):
def test_train_test_split(borders, true_borders, tsdf_with_exog):
train_start, train_end, test_start, test_end = borders
train_start_true, train_end_true, test_start_true, test_end_true = true_borders
train, test = tsdf_with_exog.train_test_split(
train_start=train_start, train_end=train_end, test_start=test_start, test_end=test_end
)
assert isinstance(train, TSDataset)
assert isinstance(test, TSDataset)
assert (train.df == tsdf_with_exog.df[train_start:train_end]).all().all()
assert (train.df == tsdf_with_exog.df[train_start_true:train_end_true]).all().all()
assert (train.df_exog == tsdf_with_exog.df_exog).all().all()
assert (test.df == tsdf_with_exog.df[test_start:test_end]).all().all()
assert (test.df == tsdf_with_exog.df[test_start_true:test_end_true]).all().all()
assert (test.df_exog == tsdf_with_exog.df_exog).all().all()


@pytest.mark.parametrize(
"borders, match",
(
(("2021-01-01", "2021-06-20", "2021-06-21", "2021-07-01"), "Min timestamp in df is"),
(("2021-02-01", "2021-06-20", "2021-06-21", "2021-08-01"), "Max timestamp in df is"),
),
)
def test_train_test_split_warning(borders, match, tsdf_with_exog):
train_start, train_end, test_start, test_end = borders
with pytest.warns(UserWarning, match=match):
tsdf_with_exog.train_test_split(
train_start=train_start, train_end=train_end, test_start=test_start, test_end=test_end
)


def test_train_test_split_failed(tsdf_with_exog):
with pytest.raises(ValueError, match="train_end or test_start should be defined"):
tsdf_with_exog.train_test_split(
train_start="2021-02-03", train_end=None, test_start=None, test_end="2021-07-01"
)


def test_dataset_datetime_convertion():
classic_df = generate_ar_df(periods=30, start_time="2021-06-01", n_segments=2)
classic_df["timestamp"] = classic_df["timestamp"].astype(str)
Expand Down