diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index 00655c677..3dbeae04c 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -66,12 +66,14 @@ def __init__(self, df: pd.DataFrame, freq: str, df_exog: Optional[pd.DataFrame] def transform(self, transforms: Iterable[Transform]): """Apply given transform to the data.""" + self._check_endings() self.transforms = transforms for transform in self.transforms: self.df = transform.transform(self.df) def fit_transform(self, transforms: Iterable[Transform]): """Fit and apply given transforms to the data.""" + self._check_endings() self.transforms = transforms for transform in self.transforms: self.df = transform.fit_transform(self.df) @@ -139,6 +141,13 @@ def _merge_exog(self, df): df = pd.merge(df, self.df_exog, left_index=True, right_index=True).sort_index(axis=1) return df + def _check_endings(self): + """Check that all targets ends at the same timestamp.""" + max_index = self.df.index.max() + for segment in self.df.columns.get_level_values("segment"): + if np.isnan(self.df.loc[max_index, pd.IndexSlice[segment, "target"]]): + raise ValueError(f"All segments should end at the same timestamp") + def inverse_transform(self): """Apply inverse transform method of transforms to the data. Applied in revered order. diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index 14f82903d..a259f6b1b 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -32,6 +32,28 @@ def tsdf_with_exog() -> TSDataset: return ts +def test_same_ending_error_raise(): + timestamp = pd.date_range("2021-01-01", "2021-02-01") + df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"}) + df2 = pd.DataFrame({"timestamp": timestamp[:-5], "target": 12, "segment": "2"}) + df = pd.concat([df1, df2], ignore_index=True) + df = TSDataset.to_dataset(df) + ts = TSDataset(df=df, freq="D") + + with pytest.raises(ValueError): + ts.fit_transform([]) + + +def test_same_ending_error_pass(): + timestamp = pd.date_range("2021-01-01", "2021-02-01") + df1 = pd.DataFrame({"timestamp": timestamp, "target": 11, "segment": "1"}) + df2 = pd.DataFrame({"timestamp": timestamp, "target": 12, "segment": "2"}) + df = pd.concat([df1, df2], ignore_index=True) + df = TSDataset.to_dataset(df) + ts = TSDataset(df=df, freq="D") + ts.fit_transform([]) + + def test_categorical_after_call_to_pandas(): classic_df = generate_ar_df(periods=30, start_time="2021-06-01", n_segments=2) classic_df["categorical_column"] = [0] * 30 + [1] * 30