Skip to content

Commit

Permalink
Create BottomUpReconciliator (#1058)
Browse files Browse the repository at this point in the history
  • Loading branch information
brsnw250 authored Jan 3, 2023
1 parent 6985a21 commit 5ca0eeb
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 52 deletions.
1 change: 1 addition & 0 deletions etna/reconciliation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from etna.reconciliation.base import BaseReconciliator
from etna.reconciliation.bottom_up import BottomUpReconciliator
from etna.reconciliation.top_down import TopDownReconciliator
50 changes: 50 additions & 0 deletions etna/reconciliation/bottom_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from etna.datasets import TSDataset
from etna.reconciliation.base import BaseReconciliator


class BottomUpReconciliator(BaseReconciliator):
"""Bottom-up reconciliation."""

def __init__(self, target_level: str, source_level: str):
"""Create bottom-up reconciliator from ``source_level`` to ``target_level``.
Parameters
----------
target_level:
Level to be reconciled from the forecasts.
source_level:
Level to be forecasted.
"""
super().__init__(target_level=target_level, source_level=source_level)

def fit(self, ts: TSDataset) -> "BottomUpReconciliator":
"""Fit the reconciliator parameters.
Parameters
----------
ts:
TSDataset on the level which is lower or equal to ``target_level``, ``source_level``.
Returns
-------
:
Fitted instance of reconciliator.
"""
if ts.hierarchical_structure is None:
raise ValueError(f"The method can be applied only to instances with a hierarchy!")

current_level_index = ts.hierarchical_structure.get_level_depth(ts.current_df_level) # type: ignore
source_level_index = ts.hierarchical_structure.get_level_depth(self.source_level)
target_level_index = ts.hierarchical_structure.get_level_depth(self.target_level)

if source_level_index < target_level_index:
raise ValueError("Source level should be lower or equal in the hierarchy than the target level!")

if current_level_index < source_level_index:
raise ValueError("Current TSDataset level should be lower or equal in the hierarchy than the source level!")

self.mapping_matrix = ts.hierarchical_structure.get_summing_matrix(
target_level=self.target_level, source_level=self.source_level
)

return self
7 changes: 6 additions & 1 deletion etna/reconciliation/top_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def _missing_(cls, method):


class TopDownReconciliator(BaseReconciliator):
"""Top-down reconciliation methods."""
"""Top-down reconciliation methods.
Notes
-----
Top-down reconciliation methods support only non-negative data.
"""

def __init__(self, target_level: str, source_level: str, period: int, method: str):
"""Create top-down reconciliator from ``source_level`` to ``target_level``.
Expand Down
53 changes: 53 additions & 0 deletions tests/test_reconciliation/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset


@pytest.fixture
def market_level_df_w_negatives():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"target": [-1.0, 2.0] + [0, -20.0],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure):
ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure):
ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure):
ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_no_hierarchy_ts(market_level_df):
ts = TSDataset(df=market_level_df, freq="D")
return ts
98 changes: 98 additions & 0 deletions tests/test_reconciliation/test_bottom_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
import pytest

from etna.reconciliation.bottom_up import BottomUpReconciliator


@pytest.mark.parametrize(
"target_level,source_level,error_message",
(
(
"market",
"total",
"Source level should be lower or equal in the hierarchy than the target level!",
),
(
"total",
"product",
"Current TSDataset level should be lower or equal in the hierarchy than the source level!",
),
),
)
def test_bottom_up_reconcile_level_order_errors(
market_level_simple_hierarchical_ts, target_level, source_level, error_message
):
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
with pytest.raises(ValueError, match=error_message):
reconciler.fit(market_level_simple_hierarchical_ts)


@pytest.mark.parametrize(
"target_level,source_level",
(("abc", "total"), ("market", "abc")),
)
def test_bottom_up_reconcile_invalid_level_errors(market_level_simple_hierarchical_ts, target_level, source_level):
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
with pytest.raises(ValueError, match="Invalid level name: abc"):
reconciler.fit(market_level_simple_hierarchical_ts)


def test_bottom_up_reconcile_no_hierarchy_error(simple_no_hierarchy_ts):
reconciler = BottomUpReconciliator(target_level="market", source_level="total")
with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"):
reconciler.fit(simple_no_hierarchy_ts)


@pytest.mark.parametrize(
"ts_name,target_level,source_level,answer",
(
(
"product_level_simple_hierarchical_ts",
"product",
"product",
np.identity(4),
),
(
"product_level_simple_hierarchical_ts",
"market",
"product",
np.array([[1, 1, 0, 0], [0, 0, 1, 1]]),
),
(
"product_level_simple_hierarchical_ts",
"total",
"product",
np.array([[1, 1, 1, 1]]),
),
(
"product_level_simple_hierarchical_ts",
"total",
"market",
np.array([[1, 1]]),
),
(
"market_level_simple_hierarchical_ts",
"total",
"market",
np.array([[1, 1]]),
),
(
"total_level_simple_hierarchical_ts",
"total",
"total",
np.array([[1]]),
),
),
)
def test_bottom_up_reconcile_fit(ts_name, target_level, source_level, answer, request):
ts = request.getfixturevalue(ts_name)
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
reconciler.fit(ts)
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)


def test_bottom_up_reconcile_fit_w_nans(market_level_simple_hierarchical_ts_w_nans):
answer = np.array([[1, 1]])
reconciler = BottomUpReconciliator(source_level="market", target_level="total")
reconciler.fit(market_level_simple_hierarchical_ts_w_nans)
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)
51 changes: 0 additions & 51 deletions tests/test_reconciliation/test_top_down.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,9 @@
import numpy as np
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.reconciliation.top_down import TopDownReconciliator


@pytest.fixture
def market_level_df_w_negatives():
df = pd.DataFrame(
{
"timestamp": ["2000-01-01", "2000-01-02"] * 2,
"segment": ["X"] * 2 + ["Y"] * 2,
"target": [-1.0, 2.0] + [0, -20.0],
}
)
df = TSDataset.to_dataset(df)
return df


@pytest.fixture
def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure):
ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure):
ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure):
ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure):
ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure)
return ts


@pytest.fixture
def simple_no_hierarchy_ts(market_level_df):
ts = TSDataset(df=market_level_df, freq="D")
return ts


@pytest.mark.parametrize(
"period",
(0, -1),
Expand Down

0 comments on commit 5ca0eeb

Please sign in to comment.