Skip to content

Enhance TSDataset to work with hierarchical series #1048

Merged
merged 18 commits into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions etna/datasets/hierarchical_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from scipy.sparse import csr_matrix
from scipy.sparse import lil_matrix
Expand Down Expand Up @@ -147,7 +146,7 @@ def get_summing_matrix(self, target_level: str, source_level: str) -> csr_matrix
target_idx = self._level_to_index[target_level]
source_idx = self._level_to_index[source_level]
except KeyError as e:
raise ValueError("Invalid level name: " + e.args[0])
raise ValueError(f"Invalid level name: {e.args[0]}")

if target_idx >= source_idx:
raise ValueError("Target level must be higher in hierarchy than source level!")
Expand All @@ -172,10 +171,10 @@ def get_level_segments(self, level_name: str) -> List[str]:
"""Get all segments from particular level."""
try:
return self._level_series[level_name]
except KeyError as e:
raise ValueError("Invalid level name: " + e.args[0])
except KeyError:
raise ValueError(f"Invalid level name: {level_name}")

def get_segment_level(self, segment: str) -> Union[str, None]:
def get_segment_level(self, segment: str) -> str:
"""Get level name for provided segment."""
try:
return self._segment_to_level[segment]
Expand Down
37 changes: 35 additions & 2 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing_extensions import Literal

from etna import SETTINGS
from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.utils import _TorchDataset
from etna.loggers import tslogger

Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
freq: str,
df_exog: Optional[pd.DataFrame] = None,
known_future: Union[Literal["all"], Sequence] = (),
hierarchical_structure: Optional[HierarchicalStructure] = None,
):
"""Init TSDataset.

Expand All @@ -106,6 +108,8 @@ def __init__(
known_future:
columns in ``df_exog[known_future]`` that are regressors,
if "all" value is given, all columns are meant to be regressors
hierarchical_structure:
Structure of the levels in the hierarchy. If None, there is no hierarchical structure in the dataset.
"""
self.raw_df = self._prepare_df(df)
self.raw_df.index = pd.to_datetime(self.raw_df.index)
Expand All @@ -132,13 +136,36 @@ def __init__(
self.known_future = self._check_known_future(known_future, df_exog)
self._regressors = copy(self.known_future)

self.hierarchical_structure = hierarchical_structure
self.current_df_level: Optional[str] = self._get_dataframe_level(df=self.df)
self.current_df_exog_level: Optional[str] = None

if df_exog is not None:
self.df_exog = df_exog.copy(deep=True)
self.df_exog.index = pd.to_datetime(self.df_exog.index)
self.df = self._merge_exog(self.df)
self.current_df_exog_level = self._get_dataframe_level(df=self.df_exog)
if self.current_df_level == self.current_df_exog_level:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
self.df = self._merge_exog(self.df)
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

self.transforms: Optional[Sequence["Transform"]] = None

def _get_dataframe_level(self, df: pd.DataFrame) -> Optional[str]:
"""Return the level of the passed dataframe in hierarchical structure."""
if self.hierarchical_structure is None:
return None

df_segments = df.columns.get_level_values("segment").unique()
segment_levels = {self.hierarchical_structure.get_segment_level(segment=segment) for segment in df_segments}
if len(segment_levels) != 1:
raise ValueError("Segments in dataframe are from more than 1 hierarchical levels!")

df_level = segment_levels.pop()
level_segments = self.hierarchical_structure.get_level_segments(level_name=df_level)
if len(df_segments) != len(level_segments):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this check should be here or not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you think so? You mean that it is not the real case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really think that it should be in exactly this function. Check is useful, but it looks somewhat artificial in this function because we have already got that we wanted here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, my problem with this function is what it checks this conditions every run. But according to the name it should just return the value and don't check correctness during each run.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm generally ok with this, it is optional to fix.

raise ValueError("Some segments of hierarchical level are missing in dataframe!")

return df_level

def transform(self, transforms: Sequence["Transform"]):
"""Apply given transform to the data."""
self._check_endings(warning=True)
Expand Down Expand Up @@ -294,7 +321,7 @@ def make_future(self, future_steps: int, tail_steps: int = 0) -> "TSDataset":
df = self.raw_df.reindex(new_index)
df.index.name = "timestamp"

if self.df_exog is not None:
if self.df_exog is not None and self.current_df_level == self.current_df_exog_level:
df = self._merge_exog(df)

# check if we have enough values in regressors
Expand Down Expand Up @@ -871,6 +898,12 @@ def index(self) -> pd.core.indexes.datetimes.DatetimeIndex:
"""
return self.df.index

def level_names(self) -> Optional[List[str]]:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""Return names of the levels in the hierarchical structure."""
if self.hierarchical_structure is None:
return None
return self.hierarchical_structure.level_names

@property
def columns(self) -> pd.core.indexes.multi.MultiIndex:
"""Return columns of ``self.df``.
Expand Down
169 changes: 169 additions & 0 deletions tests/test_datasets/test_hierarchical_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import pandas as pd
import pytest

from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset


@pytest.fixture
def hierarchical_structure():
hs = HierarchicalStructure(
level_structure={"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]},
level_names=["total", "market", "product"],
)
return hs


@pytest.fixture
def diff_level_segments_df():
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["a"] * 2,
"target": [1, 2] + [10, 20],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def diff_level_segments_df_exog():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["a"] * 2,
"exog": [1, 2] + [10, 20],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def missing_segments_df():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"],
"segment": ["X"] * 2,
"target": [1, 2],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def market_level_df():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"target": [1, 2] + [10, 20],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def market_level_df_exog():
df_exog = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"exog": [1, 2] + [10, 20],
}
)
df_exog = TSDataset.to_dataset(df_exog)
return df_exog


@pytest.fixture
def product_level_df():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 4,
"segment": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2,
"target": [1, 2] + [10, 20] + [100, 200] + [1000, 2000],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def simple_hierarchical_ts(market_level_df, hierarchical_structure):
df = market_level_df
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
ts = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


def test_get_dataframe_level_diff_level_segments_fails(diff_level_segments_df, simple_hierarchical_ts):
with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
simple_hierarchical_ts._get_dataframe_level(df=diff_level_segments_df)


def test_get_dataframe_level_missing_segments_fails(missing_segments_df, simple_hierarchical_ts):
with pytest.raises(ValueError, match="Some segments of hierarchical level are missing in dataframe!"):
simple_hierarchical_ts._get_dataframe_level(df=missing_segments_df)


@pytest.mark.parametrize("df, expected_level", [("market_level_df", "market"), ("product_level_df", "product")])
def test_get_dataframe(df, expected_level, simple_hierarchical_ts, request):
df = request.getfixturevalue(df)
df_level = simple_hierarchical_ts._get_dataframe_level(df=df)
assert df_level == expected_level


def test_init_diff_level_segments_df_fails(diff_level_segments_df, hierarchical_structure):
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
df = diff_level_segments_df
with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
_ = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure)


def test_init_diff_level_segments_df_exog_fails(market_level_df, diff_level_segments_df_exog, hierarchical_structure):
df, df_exog = market_level_df, diff_level_segments_df_exog
with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
_ = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)


def test_init_df_same_level_df_exog(market_level_df, market_level_df_exog, hierarchical_structure):
df, df_exog = market_level_df, market_level_df_exog
ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
df_columns = set(ts.columns.get_level_values("feature"))
assert df_columns == {"target", "exog"}


def test_init_df_diff_level_df_exog(product_level_df, market_level_df_exog, hierarchical_structure):
df, df_exog = product_level_df, market_level_df_exog
ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
df_columns = set(ts.columns.get_level_values("feature"))
assert df_columns == {"target"}


def test_make_future_df_same_level_df_exog(market_level_df, market_level_df_exog, hierarchical_structure):
df, df_exog = market_level_df, market_level_df_exog
ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
future = ts.make_future(future_steps=4)
future_columns = set(future.columns.get_level_values("feature"))
assert future_columns == {"target", "exog"}


def test_make_future_df_diff_level_df_exog(product_level_df, market_level_df_exog, hierarchical_structure):
df, df_exog = product_level_df, market_level_df_exog
ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
future = ts.make_future(future_steps=4)
future_columns = set(future.columns.get_level_values("feature"))
assert future_columns == {"target"}


alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
def test_level_names_with_hierarchical_structure(simple_hierarchical_ts, expected_names=["total", "market", "product"]):
ts_level_names = simple_hierarchical_ts.level_names()
assert sorted(ts_level_names) == sorted(expected_names)


def test_level_names_without_hierarchical_structure(market_level_df):
ts = TSDataset(df=market_level_df, freq="D")
ts_level_names = ts.level_names()
assert ts_level_names is None