-
Notifications
You must be signed in to change notification settings - Fork 80
Merge hierarchical pipeline #1083
Merged
Merged
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
ebadcb2
`HierarchicalStructure` class (#1044)
brsnw250 62af639
Add `to_hierarchical_dataset` method to `TSDataset` (#1049)
alex-hse-repository 37ab295
Enhance `TSDataset` to work with hierarchical series (#1048)
alex-hse-repository 0d05503
Create `get_level_dataset` method (#1053)
brsnw250 363b013
Create `BaseReconciler` (#1054)
alex-hse-repository 6985a21
Create `TopDownReconciliator` (#1055)
brsnw250 5ca0eeb
Create `BottomUpReconciliator` (#1058)
brsnw250 5f4c17e
Create `generate_hierarchical_df` (#1060)
brsnw250 edd801d
Create `HierarchicalPipeline` (#1059)
brsnw250 455cdbb
Hierarchical pipeline example notebook (#1070)
brsnw250 2a183f2
Add prediction intervals support to HierarchicalPipeline (#1079)
brsnw250 70154ff
Merge branch 'master' into hierarchical_pipeline
alex-hse-repository da06d0a
Update changelog
alex-hse-repository d7fc425
Add n_periods to generate_hierarchical_df
alex-hse-repository 6d36a11
Add reises block
alex-hse-repository 51b2715
Fix notebook
alex-hse-repository 1c92f77
Fix notebook
alex-hse-repository c040785
Fix notebook
alex-hse-repository 6c03334
Fix notebook
alex-hse-repository 387a4ed
Add comments about random dataset generation
alex-hse-repository db7db80
Fix formating
alex-hse-repository 0844bd6
Fix docs
alex-hse-repository File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,3 +16,4 @@ Tutorials | |
tutorials/ensembles | ||
tutorials/NN_examples | ||
tutorials/classification | ||
tutorials/hierarchical_pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from etna.datasets.datasets_generation import generate_ar_df | ||
from etna.datasets.datasets_generation import generate_const_df | ||
from etna.datasets.datasets_generation import generate_from_patterns_df | ||
from etna.datasets.datasets_generation import generate_hierarchical_df | ||
from etna.datasets.datasets_generation import generate_periodic_df | ||
from etna.datasets.hierarchical_structure import HierarchicalStructure | ||
from etna.datasets.tsdataset import TSDataset | ||
from etna.datasets.utils import duplicate_data | ||
from etna.datasets.utils import set_columns_wide |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
from collections import defaultdict | ||
from itertools import chain | ||
from queue import Queue | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Tuple | ||
|
||
from scipy.sparse import csr_matrix | ||
from scipy.sparse import lil_matrix | ||
|
||
from etna.core import BaseMixin | ||
|
||
|
||
class HierarchicalStructure(BaseMixin): | ||
"""Represents hierarchical structure of TSDataset.""" | ||
|
||
def __init__(self, level_structure: Dict[str, List[str]], level_names: Optional[List[str]] = None): | ||
"""Init HierarchicalStructure. | ||
|
||
Parameters | ||
---------- | ||
level_structure: | ||
Adjacency list describing the structure of the hierarchy tree (i.e. {"total":["X", "Y"], "X":["a", "b"], "Y":["c", "d"]}). | ||
level_names: | ||
Names of levels in the hierarchy in the order from top to bottom (i.e. ["total", "category", "product"]). | ||
If None is passed, level names are generated automatically with structure "level_<level_index>". | ||
""" | ||
self.level_structure = level_structure | ||
self._hierarchy_root = self._find_tree_root(self.level_structure) | ||
self._num_nodes = self._find_num_nodes(self.level_structure) | ||
|
||
hierarchy_levels = self._find_hierarchy_levels() | ||
tree_depth = len(hierarchy_levels) | ||
|
||
self.level_names = self._get_level_names(level_names, tree_depth) | ||
self._level_series: Dict[str, Tuple[str, ...]] = { | ||
self.level_names[i]: tuple(hierarchy_levels[i]) for i in range(tree_depth) | ||
} | ||
self._level_to_index: Dict[str, int] = {self.level_names[i]: i for i in range(tree_depth)} | ||
|
||
self._segment_num_reachable_leafs: Dict[str, int] = self._get_num_reachable_leafs(hierarchy_levels) | ||
|
||
self._segment_to_level: Dict[str, str] = { | ||
segment: level for level in self._level_series for segment in self._level_series[level] | ||
} | ||
|
||
@staticmethod | ||
def _get_level_names(level_names: Optional[List[str]], tree_depth: int) -> List[str]: | ||
"""Assign level names if not provided.""" | ||
if level_names is None: | ||
level_names = [f"level_{i}" for i in range(tree_depth)] | ||
|
||
if len(level_names) != tree_depth: | ||
raise ValueError("Length of `level_names` must be equal to hierarchy tree depth!") | ||
|
||
return level_names | ||
|
||
@staticmethod | ||
def _find_tree_root(hierarchy_structure: Dict[str, List[str]]) -> str: | ||
"""Find hierarchy top level (root of tree).""" | ||
children = set(chain(*hierarchy_structure.values())) | ||
parents = set(hierarchy_structure.keys()) | ||
|
||
tree_roots = parents.difference(children) | ||
if len(tree_roots) != 1: | ||
raise ValueError("Invalid tree definition: unable to find root!") | ||
|
||
return tree_roots.pop() | ||
|
||
@staticmethod | ||
def _find_num_nodes(hierarchy_structure: Dict[str, List[str]]) -> int: | ||
"""Count number of nodes in tree.""" | ||
children = set(chain(*hierarchy_structure.values())) | ||
parents = set(hierarchy_structure.keys()) | ||
|
||
num_nodes = len(children | parents) | ||
|
||
num_edges = sum(map(len, hierarchy_structure.values())) | ||
if num_edges != num_nodes - 1: | ||
raise ValueError("Invalid tree definition: invalid number of nodes and edges!") | ||
|
||
return num_nodes | ||
|
||
def _find_hierarchy_levels(self) -> Dict[int, List[str]]: | ||
"""Traverse hierarchy tree to group segments into levels.""" | ||
leaves_levels = set() | ||
levels = defaultdict(list) | ||
seen_nodes = {self._hierarchy_root} | ||
queue: Queue = Queue() | ||
queue.put((self._hierarchy_root, 0)) | ||
while not queue.empty(): | ||
node, level = queue.get() | ||
levels[level].append(node) | ||
child_nodes = self.level_structure.get(node, []) | ||
|
||
if len(child_nodes) == 0: | ||
leaves_levels.add(level) | ||
|
||
for adj_node in child_nodes: | ||
queue.put((adj_node, level + 1)) | ||
seen_nodes.add(adj_node) | ||
|
||
if len(seen_nodes) != self._num_nodes: | ||
raise ValueError("Invalid tree definition: disconnected graph!") | ||
|
||
if len(leaves_levels) != 1: | ||
raise ValueError("All hierarchy tree leaves must be on the same level!") | ||
|
||
return levels | ||
|
||
def _get_num_reachable_leafs(self, hierarchy_levels: Dict[int, List[str]]) -> Dict[str, int]: | ||
"""Compute subtree size for each node.""" | ||
num_reachable_leafs: Dict[str, int] = dict() | ||
for level in sorted(hierarchy_levels.keys(), reverse=True): | ||
for node in hierarchy_levels[level]: | ||
if node in self.level_structure: | ||
num_reachable_leafs[node] = sum( | ||
num_reachable_leafs[child_node] for child_node in self.level_structure[node] | ||
) | ||
|
||
else: | ||
num_reachable_leafs[node] = 1 | ||
|
||
return num_reachable_leafs | ||
|
||
def get_summing_matrix(self, target_level: str, source_level: str) -> csr_matrix: | ||
"""Get summing matrix for transition from source level to target level. | ||
|
||
Generation algorithm is based on summing matrix structure. Number of 1 in such matrices equals to | ||
number of nodes on the source level. Each row of summing matrices has ones only for source level nodes that | ||
belongs to subtree rooted from corresponding target level node. BFS order of nodes on levels view simplifies | ||
algorithm to calculation necessary offsets for each row. | ||
|
||
Parameters | ||
---------- | ||
target_level: | ||
Name of target level. | ||
source_level: | ||
Name of source level. | ||
|
||
Returns | ||
------- | ||
: | ||
Summing matrix from source level to target level | ||
|
||
""" | ||
try: | ||
target_idx = self._level_to_index[target_level] | ||
source_idx = self._level_to_index[source_level] | ||
except KeyError as e: | ||
raise ValueError(f"Invalid level name: {e.args[0]}") | ||
|
||
if target_idx > source_idx: | ||
raise ValueError("Target level must be higher or equal in hierarchy than source level!") | ||
|
||
target_level_segment = self.get_level_segments(target_level) | ||
source_level_segment = self.get_level_segments(source_level) | ||
summing_matrix = lil_matrix((len(target_level_segment), len(source_level_segment)), dtype="int32") | ||
|
||
current_source_segment_id = 0 | ||
for current_target_segment_id, segment in enumerate(target_level_segment): | ||
num_reachable_leafs_left = self._segment_num_reachable_leafs[segment] | ||
|
||
while num_reachable_leafs_left > 0: | ||
source_segment = source_level_segment[current_source_segment_id] | ||
num_reachable_leafs_left -= self._segment_num_reachable_leafs[source_segment] | ||
summing_matrix[current_target_segment_id, current_source_segment_id] = 1 | ||
current_source_segment_id += 1 | ||
|
||
return summing_matrix.tocsr() | ||
|
||
def get_level_segments(self, level_name: str) -> List[str]: | ||
"""Get all segments from particular level.""" | ||
try: | ||
return list(self._level_series[level_name]) | ||
except KeyError: | ||
raise ValueError(f"Invalid level name: {level_name}") | ||
|
||
def get_segment_level(self, segment: str) -> str: | ||
"""Get level name for provided segment.""" | ||
try: | ||
return self._segment_to_level[segment] | ||
except KeyError: | ||
raise ValueError(f"Segment {segment} is out of the hierarchy") | ||
|
||
def get_level_depth(self, level_name: str) -> int: | ||
"""Get level depth in a hierarchy tree.""" | ||
try: | ||
return self._level_to_index[level_name] | ||
except KeyError: | ||
raise ValueError(f"Invalid level name: {level_name}") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From method name and docstring I can't understand how many levels are generated during this method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add
n_levels
parameter to make it more explicit