Skip to content

Commit

Permalink
Update tsdataset.py (issue #741) (#744)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Gabdushev <[email protected]>
  • Loading branch information
mvakhmenin and martins0n authored Jun 14, 2022
1 parent e999ae2 commit 52307d2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Change format of holidays for holiday_plot ([#708](https://github.com/tinkoff-ai/etna/pull/708))
-
-
- Make TSDataset method to_dataset work with copy of the passed dataframe ([#741](https://github.com/tinkoff-ai/etna/pull/741))
-
- Make feature selection transforms return columns in inverse_transform ([#688](https://github.com/tinkoff-ai/etna/issues/688))
-
Expand Down
17 changes: 9 additions & 8 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,16 +657,17 @@ def to_dataset(df: pd.DataFrame) -> pd.DataFrame:
2021-01-04 3 8
2021-01-05 4 9
"""
df["timestamp"] = pd.to_datetime(df["timestamp"])
df["segment"] = df["segment"].astype(str)
feature_columns = df.columns.tolist()
df_copy = df.copy(deep=True)
df_copy["timestamp"] = pd.to_datetime(df_copy["timestamp"])
df_copy["segment"] = df_copy["segment"].astype(str)
feature_columns = df_copy.columns.tolist()
feature_columns.remove("timestamp")
feature_columns.remove("segment")
df = df.pivot(index="timestamp", columns="segment")
df = df.reorder_levels([1, 0], axis=1)
df.columns.names = ["segment", "feature"]
df = df.sort_index(axis=1, level=(0, 1))
return df
df_copy = df_copy.pivot(index="timestamp", columns="segment")
df_copy = df_copy.reorder_levels([1, 0], axis=1)
df_copy.columns.names = ["segment", "feature"]
df_copy = df_copy.sort_index(axis=1, level=(0, 1))
return df_copy

def _find_all_borders(
self,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,11 @@ def test_inverse_transform_back_included_columns(ts_with_features, columns, retu
ts_with_features.fit_transform([transform])
ts_with_features.inverse_transform()
assert set(original_regressors) == set(ts_with_features.regressors)


def test_to_dataset_not_modify_dataframe():
    """Check that ``TSDataset.to_dataset`` leaves the frame passed to it unchanged.

    A long-format frame is snapshotted with a deep copy, handed to
    ``to_dataset``, and then compared against the snapshot.
    """
    timestamp = pd.date_range("2021-01-01", "2021-02-01")
    # ``segment`` is an int here, so an in-place cast of that column by
    # ``to_dataset`` would change its dtype and be caught by the comparison.
    df_original = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": 1})
    df_expected = df_original.copy(deep=True)
    # Only the side effect on the input matters; the wide-format result is discarded.
    TSDataset.to_dataset(df_original)
    pd.testing.assert_frame_equal(df_original, df_expected)

1 comment on commit 52307d2

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.