class HierarchicalStructure(BaseMixin):
    """Represents hierarchical structure of TSDataset."""

    def __init__(self, level_structure: Dict[str, List[str]], level_names: Optional[List[str]] = None):
        """Init HierarchicalStructure.

        Parameters
        ----------
        level_structure:
            Adjacency list describing the structure of the hierarchy tree
            (i.e. {"total":["X", "Y"], "X":["a", "b"], "Y":["c", "d"]}).
        level_names:
            Names of levels in the hierarchy in the order from top to bottom
            (i.e. ["total", "category", "product"]).
            If None is passed, level names are generated automatically as "level_<level_index>".

        Raises
        ------
        ValueError:
            If ``level_structure`` does not describe a single tree whose leaves are all on the same level,
            or if ``level_names`` has wrong length or contains duplicates.
        """
        self.level_structure = level_structure
        self._hierarchy_root = self._find_tree_root(self.level_structure)
        # Also validates that the number of edges is consistent with a tree.
        self._num_nodes = self._find_num_nodes(self.level_structure)

        hierarchy_levels = self._find_hierarchy_levels()
        tree_depth = len(hierarchy_levels)

        self.level_names = self._get_level_names(level_names, tree_depth)
        self._level_series: Dict[str, List[str]] = {
            name: hierarchy_levels[index] for index, name in enumerate(self.level_names)
        }
        self._level_to_index: Dict[str, int] = {name: index for index, name in enumerate(self.level_names)}

        self._segment_num_reachable_leafs: Dict[str, int] = self._get_num_reachable_leafs(hierarchy_levels)

        self._segment_to_level: Dict[str, str] = {
            segment: level for level, segments in self._level_series.items() for segment in segments
        }

    @staticmethod
    def _get_level_names(level_names: Optional[List[str]], tree_depth: int) -> List[str]:
        """Assign level names if not provided and validate them.

        Raises
        ------
        ValueError:
            If the number of names differs from ``tree_depth`` or the names are not unique.
        """
        if level_names is None:
            level_names = [f"level_{i}" for i in range(tree_depth)]

        if len(level_names) != tree_depth:
            raise ValueError("Length of `level_names` must be equal to hierarchy tree depth!")

        # Names are used as dictionary keys; duplicates would silently merge levels.
        if len(set(level_names)) != len(level_names):
            raise ValueError("Names in `level_names` must be unique!")

        return level_names

    @staticmethod
    def _find_tree_root(hierarchy_structure: Dict[str, List[str]]) -> str:
        """Find hierarchy top level (root of tree).

        The root is the only node that has children but is nobody's child.

        Raises
        ------
        ValueError:
            If there is no such node or there is more than one.
        """
        children = set(chain.from_iterable(hierarchy_structure.values()))
        parents = set(hierarchy_structure.keys())

        tree_roots = parents.difference(children)
        if len(tree_roots) != 1:
            raise ValueError("Invalid tree definition: unable to find root!")

        return tree_roots.pop()

    @staticmethod
    def _find_num_nodes(hierarchy_structure: Dict[str, List[str]]) -> int:
        """Count number of nodes in tree.

        Raises
        ------
        ValueError:
            If the number of edges is not equal to the number of nodes minus one.
        """
        children = set(chain.from_iterable(hierarchy_structure.values()))
        parents = set(hierarchy_structure.keys())

        num_nodes = len(children | parents)

        num_edges = sum(map(len, hierarchy_structure.values()))
        if num_edges != num_nodes - 1:
            raise ValueError("Invalid tree definition: invalid number of nodes and edges!")

        return num_nodes

    def _find_hierarchy_levels(self) -> Dict[int, List[str]]:
        """Traverse hierarchy tree to group segments into levels.

        Returns
        -------
        :
            Mapping from level index (0 for the root) to the nodes of this level in BFS order.

        Raises
        ------
        ValueError:
            If some nodes are unreachable from the root or leaves are on different levels.
        """
        leaves_levels = set()
        levels: Dict[int, List[str]] = {}
        seen_nodes = {self._hierarchy_root}

        # Level-by-level BFS: visiting a whole level before descending keeps the same node order
        # as a FIFO queue would produce.
        depth = 0
        current_level_nodes = [self._hierarchy_root]
        while current_level_nodes:
            levels[depth] = current_level_nodes

            next_level_nodes: List[str] = []
            for node in current_level_nodes:
                child_nodes = self.level_structure.get(node, [])

                if not child_nodes:
                    leaves_levels.add(depth)

                next_level_nodes.extend(child_nodes)
                seen_nodes.update(child_nodes)

            current_level_nodes = next_level_nodes
            depth += 1

        if len(seen_nodes) != self._num_nodes:
            raise ValueError("Invalid tree definition: disconnected graph!")

        if len(leaves_levels) != 1:
            raise ValueError("All hierarchy tree leaves must be on the same level!")

        return levels

    def _get_num_reachable_leafs(self, hierarchy_levels: Dict[int, List[str]]) -> Dict[str, int]:
        """Compute number of leaves reachable from each node.

        A leaf (a node without children) reaches exactly one leaf: itself.
        """
        num_reachable_leafs: Dict[str, int] = dict()
        # Process bottom-up so that counts of children are ready when the parent is handled.
        for level in sorted(hierarchy_levels.keys(), reverse=True):
            for node in hierarchy_levels[level]:
                child_nodes = self.level_structure.get(node)
                if child_nodes:
                    num_reachable_leafs[node] = sum(num_reachable_leafs[child_node] for child_node in child_nodes)
                else:
                    # Covers both nodes absent from the adjacency list and nodes listed with no children.
                    num_reachable_leafs[node] = 1

        return num_reachable_leafs

    def get_summing_matrix(self, target_level: str, source_level: str) -> csr_matrix:
        """Get summing matrix for transition from source level to target level.

        The generation algorithm relies on the structure of summing matrices. The number of ones in such
        a matrix equals the number of nodes on the source level. Each row has ones only for the source level
        nodes that belong to the subtree rooted at the corresponding target level node. Since nodes on every
        level are stored in BFS order, these source nodes are contiguous, so it is enough to track
        a running offset while filling the rows.

        Parameters
        ----------
        target_level:
            Name of target level.
        source_level:
            Name of source level.

        Returns
        -------
        :
            Summing matrix from source level to target level

        Raises
        ------
        ValueError:
            If a level name is unknown or target level is not higher in the hierarchy than source level.
        """
        try:
            target_idx = self._level_to_index[target_level]
            source_idx = self._level_to_index[source_level]
        except KeyError as e:
            raise ValueError(f"Invalid level name: {e.args[0]}") from e

        if target_idx >= source_idx:
            raise ValueError("Target level must be higher in hierarchy than source level!")

        target_level_segment = self.get_level_segments(target_level)
        source_level_segment = self.get_level_segments(source_level)
        summing_matrix = lil_matrix((len(target_level_segment), len(source_level_segment)))

        current_source_segment_id = 0
        for current_target_segment_id, segment in enumerate(target_level_segment):
            num_reachable_leafs_left = self._segment_num_reachable_leafs[segment]

            # Consume consecutive source segments until all leaves of the target segment are covered.
            while num_reachable_leafs_left > 0:
                source_segment = source_level_segment[current_source_segment_id]
                num_reachable_leafs_left -= self._segment_num_reachable_leafs[source_segment]
                summing_matrix[current_target_segment_id, current_source_segment_id] = 1
                current_source_segment_id += 1

        return summing_matrix.tocsr()

    def get_level_segments(self, level_name: str) -> List[str]:
        """Get all segments from particular level.

        Raises
        ------
        ValueError:
            If there is no level with such name.
        """
        try:
            return self._level_series[level_name]
        except KeyError as e:
            raise ValueError(f"Invalid level name: {e.args[0]}") from e

    def get_segment_level(self, segment: str) -> str:
        """Get level name for provided segment.

        Raises
        ------
        ValueError:
            If the segment is not present in the hierarchy.
        """
        try:
            return self._segment_to_level[segment]
        except KeyError:
            raise ValueError(f"Segment {segment} is out of the hierarchy")
@pytest.fixture
def simple_hierarchical_struct():
    """Balanced three-level hierarchy: total -> (X, Y) -> (a, b | c, d)."""
    structure = {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}
    names = ["l1", "l2", "l3"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.fixture
def tailed_hierarchical_struct():
    """Four-level hierarchy with branches of different width."""
    structure = {"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"], "c": ["f"], "d": ["g"], "a": ["e", "h"]}
    names = ["l1", "l2", "l3", "l4"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.fixture
def long_hierarchical_struct():
    """Four-level hierarchy made of two single-child chains."""
    structure = {"total": ["X", "Y"], "X": ["a"], "Y": ["b"], "a": ["c"], "b": ["d"]}
    names = ["l1", "l2", "l3", "l4"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.mark.parametrize(
    "struct, target,source,answer",
    (
        ("tailed_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("tailed_hierarchical_struct", "l1", "l3", np.array([[1, 1, 1]])),
        ("tailed_hierarchical_struct", "l1", "l4", np.array([[1, 1, 1, 1]])),
        ("tailed_hierarchical_struct", "l2", "l3", np.array([[1, 0, 0], [0, 1, 1]])),
        ("tailed_hierarchical_struct", "l2", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])),
        ("tailed_hierarchical_struct", "l3", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
        ("simple_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("simple_hierarchical_struct", "l1", "l3", np.array([[1, 1, 1, 1]])),
        ("simple_hierarchical_struct", "l2", "l3", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])),
        ("long_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l1", "l3", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l1", "l4", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l2", "l3", np.array([[1, 0], [0, 1]])),
        ("long_hierarchical_struct", "l2", "l4", np.array([[1, 0], [0, 1]])),
        ("long_hierarchical_struct", "l3", "l4", np.array([[1, 0], [0, 1]])),
    ),
)
def test_summing_matrix(struct: str, target: str, source: str, answer: np.ndarray, request: pytest.FixtureRequest):
    """Summing matrix between two levels matches the expected dense matrix."""
    hierarchy = request.getfixturevalue(struct)
    summing_matrix = hierarchy.get_summing_matrix(target_level=target, source_level=source)
    np.testing.assert_array_almost_equal(answer, summing_matrix.toarray())


@pytest.mark.parametrize(
    "target,source,error",
    (
        ("l0", "l2", "Invalid level name: l0"),
        ("l1", "l0", "Invalid level name: l0"),
        ("l2", "l1", "Target level must be higher in hierarchy than source level!"),
    ),
)
def test_level_transition_errors(
    simple_hierarchical_struct: HierarchicalStructure,
    target: str,
    source: str,
    error: str,
):
    """Unknown level names and wrong level order are rejected."""
    with pytest.raises(ValueError, match=error):
        _ = simple_hierarchical_struct.get_summing_matrix(target_level=target, source_level=source)


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"], "c": ["e", "f"]},  # e f leaves have lower level
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["e"]},  # e has lower level
    ),
)
def test_leaves_level_errors(structure: Dict[str, List[str]]):
    """Trees with leaves on different levels are rejected."""
    expected_message = "All hierarchy tree leaves must be on the same level!"
    with pytest.raises(ValueError, match=expected_message):
        _ = HierarchicalStructure(level_structure=structure)


@pytest.mark.parametrize(
    "structure,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, "total"),
        ({"X": ["a", "b"]}, "X"),
    ),
)
def test_root_finding(structure: Dict[str, List[str]], answer: str):
    """Root is the only parent that is nobody's child."""
    root = HierarchicalStructure._find_tree_root(structure)
    assert root == answer


@pytest.mark.parametrize(
    "structure,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, 7),
        ({"X": ["a", "b"]}, 3),
    ),
)
def test_num_nodes(structure: Dict[str, List[str]], answer: int):
    """Number of nodes counts every distinct parent and child once."""
    num_nodes = HierarchicalStructure._find_num_nodes(structure)
    assert num_nodes == answer


@pytest.mark.parametrize(
    "level_names,tree_depth,answer",
    (
        (None, 3, ["level_0", "level_1", "level_2"]),
        (["l1", "l2", "l3", "l4"], 4, ["l1", "l2", "l3", "l4"]),
    ),
)
def test_get_level_names(level_names: List[str], tree_depth: int, answer: List[str]):
    """Level names are generated when missing and passed through otherwise."""
    resolved_names = HierarchicalStructure._get_level_names(level_names, tree_depth)
    assert resolved_names == answer


@pytest.mark.parametrize(
    "structure,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, [["total"], ["X", "Y"], ["a", "b", "c", "d"]]),
        ({"X": ["a", "b"]}, [["X"], ["a", "b"]]),
    ),
)
def test_find_hierarchy_levels(structure: Dict[str, List[str]], answer: List[List[str]]):
    """Nodes are grouped by depth in BFS order."""
    hierarchy = HierarchicalStructure(level_structure=structure)
    found_levels = hierarchy._find_hierarchy_levels()
    for depth in range(len(answer)):
        assert found_levels[depth] == answer[depth]


@pytest.mark.parametrize(
    "structure,answer",
    (
        (
            {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]},
            {"total": 4, "X": 2, "Y": 2, "a": 1, "b": 1, "c": 1, "d": 1},
        ),
        ({"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"]}, {"total": 3, "X": 1, "Y": 2, "a": 1, "c": 1, "d": 1}),
        ({"X": ["a", "b"]}, {"X": 2, "a": 1, "b": 1}),
    ),
)
def test_get_num_reachable_leafs(structure: Dict[str, List[str]], answer: Dict[str, int]):
    """Every node knows how many leaves are reachable from it."""
    hierarchy = HierarchicalStructure(level_structure=structure)
    found_levels = hierarchy._find_hierarchy_levels()
    leaf_counts = hierarchy._get_num_reachable_leafs(found_levels)
    assert len(leaf_counts) == len(answer)
    for segment, expected_count in answer.items():
        assert leaf_counts[segment] == expected_count


@pytest.mark.parametrize(
    "structure,level_names,answer",
    (
        (
            {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]},
            None,
            {"level_0": 0, "level_1": 1, "level_2": 2},
        ),
        (
            {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]},
            ["l1", "l2", "l3"],
            {"l1": 0, "l2": 1, "l3": 2},
        ),
        (
            {"X": ["a"]},
            None,
            {"level_0": 0, "level_1": 1},
        ),
    ),
)
def test_level_to_index(structure: Dict[str, List[str]], level_names: List[str], answer: Dict[str, int]):
    """Level names are mapped to their depth in the tree."""
    hierarchy = HierarchicalStructure(level_structure=structure, level_names=level_names)
    level_to_index = hierarchy._level_to_index
    assert len(level_to_index) == len(answer)
    for level, expected_index in answer.items():
        assert level_to_index[level] == expected_index


@pytest.mark.parametrize(
    "structure,answer",
    (
        (
            {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]},
            {
                "total": "level_0",
                "X": "level_1",
                "Y": "level_1",
                "a": "level_2",
                "b": "level_2",
                "c": "level_2",
                "d": "level_2",
            },
        ),
        ({"X": ["a"]}, {"X": "level_0", "a": "level_1"}),
    ),
)
def test_segment_to_level(structure: Dict[str, List[str]], answer: Dict[str, str]):
    """Every segment is mapped to the name of its level."""
    hierarchy = HierarchicalStructure(level_structure=structure)
    segment_to_level = hierarchy._segment_to_level
    assert len(segment_to_level) == len(answer)
    for segment, expected_level in answer.items():
        assert segment_to_level[segment] == expected_level


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "total"]},  # loop to root
        {"X": ["a", "b"], "Y": ["c", "d"]},  # 2 trees
        dict(),  # empty structure
    ),
)
def test_root_finding_errors(structure: Dict[str, List[str]]):
    """Structures without exactly one root are rejected."""
    expected_message = "Invalid tree definition: unable to find root!"
    with pytest.raises(ValueError, match=expected_message):
        _ = HierarchicalStructure(level_structure=structure)


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["X"]},  # loop
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "Y"]},  # self loop
    ),
)
def test_invalid_tree_structure_initialization_fails(structure: Dict[str, List[str]]):
    """Structures with loops are rejected by the nodes/edges consistency check."""
    expected_message = "Invalid tree definition: invalid number of nodes and edges!"
    with pytest.raises(ValueError, match=expected_message):
        _ = HierarchicalStructure(level_structure=structure)


@pytest.mark.parametrize(
    "structure,names,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, None, ["level_0", "level_1", "level_2"]),
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3"], ["l1", "l2", "l3"]),
    ),
)
def test_level_names(structure: Dict[str, List[str]], names: List[str], answer: List[str]):
    """Public ``level_names`` holds either generated or user-provided names."""
    hierarchy = HierarchicalStructure(level_structure=structure, level_names=names)
    assert hierarchy.level_names == answer


@pytest.mark.parametrize(
    "structure,names",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1"]),
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3", "l4"]),
    ),
)
def test_level_names_length_error(structure: Dict[str, List[str]], names: List[str]):
    """Number of level names must match the tree depth."""
    expected_message = "Length of `level_names` must be equal to hierarchy tree depth!"
    with pytest.raises(ValueError, match=expected_message):
        _ = HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.mark.parametrize(
    "level,answer",
    (("l1", ["total"]), ("l2", ["X", "Y"]), ("l3", ["a", "b", "c", "d"])),
)
def test_level_segments(simple_hierarchical_struct: HierarchicalStructure, level: str, answer: List[str]):
    """Segments of a level are returned by level name."""
    segments = simple_hierarchical_struct.get_level_segments(level)
    assert segments == answer


@pytest.mark.parametrize(
    "segment,answer",
    (("total", "l1"), ("Y", "l2"), ("c", "l3")),
)
def test_segments_level(simple_hierarchical_struct: HierarchicalStructure, segment: str, answer: str):
    """Level name is returned by segment name."""
    level = simple_hierarchical_struct.get_segment_level(segment)
    assert level == answer