-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create
BottomUpReconciliator
(#1058)
- Loading branch information
Showing
6 changed files
with
208 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from etna.reconciliation.base import BaseReconciliator | ||
from etna.reconciliation.bottom_up import BottomUpReconciliator | ||
from etna.reconciliation.top_down import TopDownReconciliator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from etna.datasets import TSDataset | ||
from etna.reconciliation.base import BaseReconciliator | ||
|
||
|
||
class BottomUpReconciliator(BaseReconciliator): | ||
"""Bottom-up reconciliation.""" | ||
|
||
def __init__(self, target_level: str, source_level: str): | ||
"""Create bottom-up reconciliator from ``source_level`` to ``target_level``. | ||
Parameters | ||
---------- | ||
target_level: | ||
Level to be reconciled from the forecasts. | ||
source_level: | ||
Level to be forecasted. | ||
""" | ||
super().__init__(target_level=target_level, source_level=source_level) | ||
|
||
def fit(self, ts: TSDataset) -> "BottomUpReconciliator": | ||
"""Fit the reconciliator parameters. | ||
Parameters | ||
---------- | ||
ts: | ||
TSDataset on the level which is lower or equal to ``target_level``, ``source_level``. | ||
Returns | ||
------- | ||
: | ||
Fitted instance of reconciliator. | ||
""" | ||
if ts.hierarchical_structure is None: | ||
raise ValueError(f"The method can be applied only to instances with a hierarchy!") | ||
|
||
current_level_index = ts.hierarchical_structure.get_level_depth(ts.current_df_level) # type: ignore | ||
source_level_index = ts.hierarchical_structure.get_level_depth(self.source_level) | ||
target_level_index = ts.hierarchical_structure.get_level_depth(self.target_level) | ||
|
||
if source_level_index < target_level_index: | ||
raise ValueError("Source level should be lower or equal in the hierarchy than the target level!") | ||
|
||
if current_level_index < source_level_index: | ||
raise ValueError("Current TSDataset level should be lower or equal in the hierarchy than the source level!") | ||
|
||
self.mapping_matrix = ts.hierarchical_structure.get_summing_matrix( | ||
target_level=self.target_level, source_level=self.source_level | ||
) | ||
|
||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from etna.datasets import TSDataset | ||
|
||
|
||
@pytest.fixture | ||
def market_level_df_w_negatives(): | ||
df = pd.DataFrame( | ||
{ | ||
"timestamp": ["2000-01-01", "2000-01-02"] * 2, | ||
"segment": ["X"] * 2 + ["Y"] * 2, | ||
"target": [-1.0, 2.0] + [0, -20.0], | ||
} | ||
) | ||
df = TSDataset.to_dataset(df) | ||
return df | ||
|
||
|
||
@pytest.fixture | ||
def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure): | ||
ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure) | ||
return ts | ||
|
||
|
||
@pytest.fixture | ||
def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure): | ||
ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure) | ||
return ts | ||
|
||
|
||
@pytest.fixture | ||
def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure): | ||
ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure) | ||
return ts | ||
|
||
|
||
@pytest.fixture | ||
def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure): | ||
ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure) | ||
return ts | ||
|
||
|
||
@pytest.fixture | ||
def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure): | ||
ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure) | ||
return ts | ||
|
||
|
||
@pytest.fixture | ||
def simple_no_hierarchy_ts(market_level_df): | ||
ts = TSDataset(df=market_level_df, freq="D") | ||
return ts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from etna.reconciliation.bottom_up import BottomUpReconciliator | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_level,source_level,error_message", | ||
( | ||
( | ||
"market", | ||
"total", | ||
"Source level should be lower or equal in the hierarchy than the target level!", | ||
), | ||
( | ||
"total", | ||
"product", | ||
"Current TSDataset level should be lower or equal in the hierarchy than the source level!", | ||
), | ||
), | ||
) | ||
def test_bottom_up_reconcile_level_order_errors( | ||
market_level_simple_hierarchical_ts, target_level, source_level, error_message | ||
): | ||
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level) | ||
with pytest.raises(ValueError, match=error_message): | ||
reconciler.fit(market_level_simple_hierarchical_ts) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_level,source_level", | ||
(("abc", "total"), ("market", "abc")), | ||
) | ||
def test_bottom_up_reconcile_invalid_level_errors(market_level_simple_hierarchical_ts, target_level, source_level): | ||
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level) | ||
with pytest.raises(ValueError, match="Invalid level name: abc"): | ||
reconciler.fit(market_level_simple_hierarchical_ts) | ||
|
||
|
||
def test_bottom_up_reconcile_no_hierarchy_error(simple_no_hierarchy_ts): | ||
reconciler = BottomUpReconciliator(target_level="market", source_level="total") | ||
with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"): | ||
reconciler.fit(simple_no_hierarchy_ts) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"ts_name,target_level,source_level,answer", | ||
( | ||
( | ||
"product_level_simple_hierarchical_ts", | ||
"product", | ||
"product", | ||
np.identity(4), | ||
), | ||
( | ||
"product_level_simple_hierarchical_ts", | ||
"market", | ||
"product", | ||
np.array([[1, 1, 0, 0], [0, 0, 1, 1]]), | ||
), | ||
( | ||
"product_level_simple_hierarchical_ts", | ||
"total", | ||
"product", | ||
np.array([[1, 1, 1, 1]]), | ||
), | ||
( | ||
"product_level_simple_hierarchical_ts", | ||
"total", | ||
"market", | ||
np.array([[1, 1]]), | ||
), | ||
( | ||
"market_level_simple_hierarchical_ts", | ||
"total", | ||
"market", | ||
np.array([[1, 1]]), | ||
), | ||
( | ||
"total_level_simple_hierarchical_ts", | ||
"total", | ||
"total", | ||
np.array([[1]]), | ||
), | ||
), | ||
) | ||
def test_bottom_up_reconcile_fit(ts_name, target_level, source_level, answer, request): | ||
ts = request.getfixturevalue(ts_name) | ||
reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level) | ||
reconciler.fit(ts) | ||
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4) | ||
|
||
|
||
def test_bottom_up_reconcile_fit_w_nans(market_level_simple_hierarchical_ts_w_nans): | ||
answer = np.array([[1, 1]]) | ||
reconciler = BottomUpReconciliator(source_level="market", target_level="total") | ||
reconciler.fit(market_level_simple_hierarchical_ts_w_nans) | ||
np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters