diff --git a/etna/datasets/hierarchical_structure.py b/etna/datasets/hierarchical_structure.py index c142d0a19..89f665cbd 100644 --- a/etna/datasets/hierarchical_structure.py +++ b/etna/datasets/hierarchical_structure.py @@ -4,7 +4,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Union from scipy.sparse import csr_matrix from scipy.sparse import lil_matrix @@ -147,7 +146,7 @@ def get_summing_matrix(self, target_level: str, source_level: str) -> csr_matrix target_idx = self._level_to_index[target_level] source_idx = self._level_to_index[source_level] except KeyError as e: - raise ValueError("Invalid level name: " + e.args[0]) + raise ValueError(f"Invalid level name: {e.args[0]}") if target_idx >= source_idx: raise ValueError("Target level must be higher in hierarchy than source level!") @@ -172,10 +171,10 @@ def get_level_segments(self, level_name: str) -> List[str]: """Get all segments from particular level.""" try: return self._level_series[level_name] - except KeyError as e: - raise ValueError("Invalid level name: " + e.args[0]) + except KeyError: + raise ValueError(f"Invalid level name: {level_name}") - def get_segment_level(self, segment: str) -> Union[str, None]: + def get_segment_level(self, segment: str) -> str: """Get level name for provided segment.""" try: return self._segment_to_level[segment] diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index 104a50d8a..e4238c1e4 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -93,6 +93,7 @@ def __init__( freq: str, df_exog: Optional[pd.DataFrame] = None, known_future: Union[Literal["all"], Sequence] = (), + hierarchical_structure: Optional[HierarchicalStructure] = None, ): """Init TSDataset. @@ -107,6 +108,8 @@ def __init__( known_future: columns in ``df_exog[known_future]`` that are regressors, if "all" value is given, all columns are meant to be regressors + hierarchical_structure: + Structure of the levels in the hierarchy. If None, there is no hierarchical structure in the dataset. """ self.raw_df = self._prepare_df(df) self.raw_df.index = pd.to_datetime(self.raw_df.index) @@ -133,13 +136,36 @@ def __init__( self.known_future = self._check_known_future(known_future, df_exog) self._regressors = copy(self.known_future) + self.hierarchical_structure = hierarchical_structure + self.current_df_level: Optional[str] = self._get_dataframe_level(df=self.df) + self.current_df_exog_level: Optional[str] = None + if df_exog is not None: self.df_exog = df_exog.copy(deep=True) self.df_exog.index = pd.to_datetime(self.df_exog.index) - self.df = self._merge_exog(self.df) + self.current_df_exog_level = self._get_dataframe_level(df=self.df_exog) + if self.current_df_level == self.current_df_exog_level: + self.df = self._merge_exog(self.df) self.transforms: Optional[Sequence["Transform"]] = None + def _get_dataframe_level(self, df: pd.DataFrame) -> Optional[str]: + """Return the level of the passed dataframe in hierarchical structure.""" + if self.hierarchical_structure is None: + return None + + df_segments = df.columns.get_level_values("segment").unique() + segment_levels = {self.hierarchical_structure.get_segment_level(segment=segment) for segment in df_segments} + if len(segment_levels) != 1: + raise ValueError("Segments in dataframe are from more than 1 hierarchical levels!") + + df_level = segment_levels.pop() + level_segments = self.hierarchical_structure.get_level_segments(level_name=df_level) + if len(df_segments) != len(level_segments): + raise ValueError("Some segments of hierarchical level are missing in dataframe!") + + return df_level + def transform(self, transforms: Sequence["Transform"]): """Apply given transform to the data.""" self._check_endings(warning=True) @@ -295,7 +321,7 @@ def make_future(self, future_steps: int, tail_steps: int = 0) -> "TSDataset": df = self.raw_df.reindex(new_index) df.index.name = "timestamp" - if self.df_exog is not None: + if self.df_exog is not None and self.current_df_level == self.current_df_exog_level: df = self._merge_exog(df) # check if we have enough values in regressors @@ -944,6 +970,16 @@ def index(self) -> pd.core.indexes.datetimes.DatetimeIndex: """ return self.df.index + def level_names(self) -> Optional[List[str]]: + """Return names of the levels in the hierarchical structure.""" + if self.hierarchical_structure is None: + return None + return self.hierarchical_structure.level_names + + def has_hierarchy(self) -> bool: + """Check whether dataset has hierarchical structure.""" + return self.hierarchical_structure is not None + @property def columns(self) -> pd.core.indexes.multi.MultiIndex: """Return columns of ``self.df``. diff --git a/tests/test_datasets/test_hierarchical_dataset.py b/tests/test_datasets/test_hierarchical_dataset.py index 6a8bfa6c9..754f4bd98 100644 --- a/tests/test_datasets/test_hierarchical_dataset.py +++ b/tests/test_datasets/test_hierarchical_dataset.py @@ -7,6 +7,15 @@ @pytest.fixture def hierarchical_structure(): + hs = HierarchicalStructure( + level_structure={"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + level_names=["total", "market", "product"], + ) + return hs + + +@pytest.fixture +def hierarchical_structure_complex(): hs = HierarchicalStructure( level_structure={ "total": ["77", "120"], @@ -20,6 +29,19 @@ def hierarchical_structure(): return hs +@pytest.fixture +def different_level_segments_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["a"] * 2, + "target": [1, 2] + [10, 20], + } + ) + df = TSDataset.to_dataset(df) + return df + + @pytest.fixture def level_columns_different_types_df(): df = pd.DataFrame( @@ -35,6 +57,19 @@ def level_columns_different_types_df(): return df +@pytest.fixture +def different_level_segments_df_exog(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["a"] * 2, + "exog": [1, 2] + [10, 20], + } + ) + df = TSDataset.to_dataset(df) + return df + + @pytest.fixture def product_level_df_long(): df = pd.DataFrame( @@ -48,6 +83,19 @@ def product_level_df_long(): return df +@pytest.fixture +def missing_segments_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"], + "segment": ["X"] * 2, + "target": [1, 2], + } + ) + df = TSDataset.to_dataset(df) + return df + + @pytest.fixture def product_level_df_wide(product_level_df_long): df = product_level_df_long @@ -56,6 +104,139 @@ def product_level_df_wide(product_level_df_long): return df +@pytest.fixture +def market_level_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "target": [1, 2] + [10, 20], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_df_exog(): + df_exog = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "exog": [1, 2] + [10, 20], + } + ) + df_exog = TSDataset.to_dataset(df_exog) + return df_exog + + +@pytest.fixture +def product_level_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "segment": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [1, 2] + [10, 20] + [100, 200] + [1000, 2000], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def simple_hierarchical_ts(market_level_df, hierarchical_structure): + df = market_level_df + ts = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +def test_get_dataframe_level_different_level_segments_fails(different_level_segments_df, simple_hierarchical_ts): + with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"): + simple_hierarchical_ts._get_dataframe_level(df=different_level_segments_df) + + +def test_get_dataframe_level_missing_segments_fails(missing_segments_df, simple_hierarchical_ts): + with pytest.raises(ValueError, match="Some segments of hierarchical level are missing in dataframe!"): + simple_hierarchical_ts._get_dataframe_level(df=missing_segments_df) + + +@pytest.mark.parametrize("df, expected_level", [("market_level_df", "market"), ("product_level_df", "product")]) +def test_get_dataframe(df, expected_level, simple_hierarchical_ts, request): + df = request.getfixturevalue(df) + df_level = simple_hierarchical_ts._get_dataframe_level(df=df) + assert df_level == expected_level + + +def test_init_different_level_segments_df_fails(different_level_segments_df, hierarchical_structure): + df = different_level_segments_df + with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"): + _ = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + + +def test_init_different_level_segments_df_exog_fails( + market_level_df, different_level_segments_df_exog, hierarchical_structure +): + df, df_exog = market_level_df, different_level_segments_df_exog + with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"): + _ = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + + +def test_init_df_same_level_df_exog( + market_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target", "exog"} +): + df, df_exog = market_level_df, market_level_df_exog + ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + df_columns = set(ts.columns.get_level_values("feature")) + assert df_columns == expected_columns + + +def test_init_df_different_level_df_exog( + product_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target"} +): + df, df_exog = product_level_df, market_level_df_exog + ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + df_columns = set(ts.columns.get_level_values("feature")) + assert df_columns == expected_columns + + +def test_init_missing_segmnets_df(missing_segments_df, hierarchical_structure): + df = missing_segments_df + with pytest.raises(ValueError, match="Some segments of hierarchical level are missing in dataframe!"): + _ = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + + +def test_make_future_df_same_level_df_exog( + market_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target", "exog"} +): + df, df_exog = market_level_df, market_level_df_exog + ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + future = ts.make_future(future_steps=4) + future_columns = set(future.columns.get_level_values("feature")) + assert future_columns == expected_columns + + +def test_make_future_df_different_level_df_exog( + product_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target"} +): + df, df_exog = product_level_df, market_level_df_exog + ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + future = ts.make_future(future_steps=4) + future_columns = set(future.columns.get_level_values("feature")) + assert future_columns == expected_columns + + +def test_level_names_with_hierarchical_structure(simple_hierarchical_ts, expected_names=["total", "market", "product"]): + ts_level_names = simple_hierarchical_ts.level_names() + assert sorted(ts_level_names) == sorted(expected_names) + + +def test_level_names_without_hierarchical_structure(market_level_df): + df = market_level_df + ts = TSDataset(df=df, freq="D") + ts_level_names = ts.level_names() + assert ts_level_names is None + + def test_to_hierarchical_dataset_not_change_input_df(product_level_df_long): df = product_level_df_long df_before = df.copy() @@ -105,13 +286,15 @@ def test_to_hierarchical_dataset_correct_dataframe(product_level_df_long, produc pd.testing.assert_frame_equal(df_wide_obtained, product_level_df_wide) -def test_to_hierarchical_dataset_hierarchical_structure(level_columns_different_types_df, hierarchical_structure): +def test_to_hierarchical_dataset_hierarchical_structure( + level_columns_different_types_df, hierarchical_structure_complex +): _, hs = TSDataset.to_hierarchical_dataset( df=level_columns_different_types_df, level_columns=["categorical", "string", "int"], return_hierarchy=True ) - assert hs.level_names == hierarchical_structure.level_names - for level_name in hierarchical_structure.level_names: + assert hs.level_names == hierarchical_structure_complex.level_names + for level_name in hierarchical_structure_complex.level_names: assert level_name in hs.level_names - expected_level_segments = hierarchical_structure.get_level_segments(level_name=level_name) + expected_level_segments = hierarchical_structure_complex.get_level_segments(level_name=level_name) obtained_level_segments = hs.get_level_segments(level_name=level_name) assert sorted(obtained_level_segments) == sorted(expected_level_segments)