Skip to content

Create BottomUpReconciliator #1058

Merged
merged 5 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions etna/reconciliation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from etna.reconciliation.base import BaseReconciliator
from etna.reconciliation.bottom_up import BottomUpReconciliator
from etna.reconciliation.top_down import TopDownReconciliator
50 changes: 50 additions & 0 deletions etna/reconciliation/bottom_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from etna.datasets import TSDataset
from etna.reconciliation.base import BaseReconciliator


class BottomUpReconciliator(BaseReconciliator):
"""Bottom-up reconciliation."""

def __init__(self, target_level: str, source_level: str):
"""Create bottom-up reconciliator from ``source_level`` to ``target_level``.

Parameters
----------
target_level:
Level to be reconciled from the forecasts.
source_level:
Level to be forecasted.
"""
super().__init__(target_level=target_level, source_level=source_level)

def fit(self, ts: TSDataset) -> "BottomUpReconciliator":
"""Fit the reconciliator parameters.

Parameters
----------
ts:
TSDataset on the level which is lower or equal to ``target_level``, ``source_level``.

Returns
-------
:
Fitted instance of reconciliator.
"""
if ts.hierarchical_structure is None:
raise ValueError(f"The method can be applied only to instances with a hierarchy!")

current_level_index = ts.hierarchical_structure.get_level_depth(ts.current_df_level) # type: ignore
source_level_index = ts.hierarchical_structure.get_level_depth(self.source_level)
target_level_index = ts.hierarchical_structure.get_level_depth(self.target_level)

if source_level_index < target_level_index:
raise ValueError("Source level should be lower or equal in the hierarchy than the target level!")

if current_level_index < source_level_index:
raise ValueError("Current TSDataset level should be lower or equal in the hierarchy than the source level!")

self.mapping_matrix = ts.hierarchical_structure.get_summing_matrix(
target_level=self.target_level, source_level=self.source_level
)

return self
7 changes: 6 additions & 1 deletion etna/reconciliation/top_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def _missing_(cls, method):


class TopDownReconciliator(BaseReconciliator):
"""Top-down reconciliation methods."""
"""Top-down reconciliation methods.

Notes
-----
Top-down reconciliation methods support only non-negative data.
"""

def __init__(self, target_level: str, source_level: str, period: int, method: str):
"""Create top-down reconciliator from ``source_level`` to ``target_level``.
Expand Down
53 changes: 53 additions & 0 deletions tests/test_reconciliation/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset


@pytest.fixture
def market_level_df_w_negatives():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"target": [-1.0, 2.0] + [0, -20.0],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure):
ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure):
ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure):
ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_no_hierarchy_ts(market_level_df):
ts = TSDataset(df=market_level_df, freq="D")
return ts
98 changes: 98 additions & 0 deletions tests/test_reconciliation/test_bottom_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
import pytest

from etna.reconciliation.bottom_up import BottomUpReconciliator


@pytest.mark.parametrize(
"target_level,source_level,error_message",
(
(
"market",
"total",
"Source level should be lower or equal in the hierarchy than the target level!",
),
(
"total",
"product",
"Current TSDataset level should be lower or equal in the hierarchy than the source level!",
),
),
)
def test_bottom_up_reconcile_level_order_errors(
market_level_simple_hierarchical_ts, target_level, source_level, error_message
):
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
with pytest.raises(ValueError, match=error_message):
reconciler.fit(market_level_simple_hierarchical_ts)


@pytest.mark.parametrize(
"target_level,source_level",
(("abc", "total"), ("market", "abc")),
)
def test_bottom_up_reconcile_invalid_level_errors(market_level_simple_hierarchical_ts, target_level, source_level):
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
with pytest.raises(ValueError, match="Invalid level name: abc"):
reconciler.fit(market_level_simple_hierarchical_ts)


def test_bottom_up_reconcile_no_hierarchy_error(simple_no_hierarchy_ts):
reconciler = BottomUpReconciliator(target_level="market", source_level="total")
with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"):
reconciler.fit(simple_no_hierarchy_ts)


@pytest.mark.parametrize(
"ts_name,target_level,source_level,answer",
(
(
"product_level_simple_hierarchical_ts",
"product",
"product",
np.identity(4),
),
(
"product_level_simple_hierarchical_ts",
"market",
"product",
np.array([[1, 1, 0, 0], [0, 0, 1, 1]]),
),
(
"product_level_simple_hierarchical_ts",
"total",
"product",
np.array([[1, 1, 1, 1]]),
),
(
"product_level_simple_hierarchical_ts",
"total",
"market",
np.array([[1, 1]]),
),
(
"market_level_simple_hierarchical_ts",
"total",
"market",
np.array([[1, 1]]),
),
(
"total_level_simple_hierarchical_ts",
"total",
"total",
np.array([[1]]),
),
),
)
def test_bottom_up_reconcile_fit(ts_name, target_level, source_level, answer, request):
ts = request.getfixturevalue(ts_name)
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
reconciler.fit(ts)
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)


def test_bottom_up_reconcile_fit_w_nans(market_level_simple_hierarchical_ts_w_nans):
answer = np.array([[1, 1]])
reconciler = BottomUpReconciliator(source_level="market", target_level="total")
reconciler.fit(market_level_simple_hierarchical_ts_w_nans)
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)
51 changes: 0 additions & 51 deletions tests/test_reconciliation/test_top_down.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,9 @@
import numpy as np
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.reconciliation.top_down import TopDownReconciliator


@pytest.fixture
def market_level_df_w_negatives():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"target": [-1.0, 2.0] + [0, -20.0],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure):
ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure):
ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure):
ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_no_hierarchy_ts(market_level_df):
ts = TSDataset(df=market_level_df, freq="D")
return ts


@pytest.mark.parametrize(
"period",
(0, -1),
Expand Down