Skip to content

HierarchicalStructure class #1044

Merged
merged 12 commits into from
Dec 21, 2022
1 change: 1 addition & 0 deletions etna/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from etna.datasets.datasets_generation import generate_const_df
from etna.datasets.datasets_generation import generate_from_patterns_df
from etna.datasets.datasets_generation import generate_periodic_df
from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import duplicate_data
from etna.datasets.utils import set_columns_wide
187 changes: 187 additions & 0 deletions etna/datasets/hierarchical_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from collections import defaultdict
from itertools import chain
from queue import Queue
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Union

import scipy
from scipy.sparse import lil_matrix

from etna.core import BaseMixin


class HierarchicalStructure(BaseMixin):
    """Represent the structure of a hierarchical tree of segments.

    The tree is passed as an adjacency list. On construction the definition is validated
    and its nodes are grouped into levels in breadth-first order starting from the root.
    """

    def __init__(self, level_structure: Dict[str, List[str]], level_names: Optional[List[str]] = None):
        """Init HierarchicalStructure.

        Parameters
        ----------
        level_structure:
            Adjacency list describing the structure of the hierarchy tree
            (e.g. {"total":["X", "Y"], "X":["a", "b"], "Y":["c", "d"]}).
        level_names:
            Names of levels in the hierarchy in the order from top to bottom (e.g. ["total", "category", "product"]).
            If None is passed, level names are generated automatically with structure "level_<level_index>".

        Raises
        ------
        ValueError:
            If the adjacency list does not have exactly one root, if the number of edges is not
            the number of nodes minus one, if leaves are on different levels, if some node is not
            reachable from the root, or if the length of ``level_names`` differs from the tree depth.
        """
        self._hierarchy_root: Optional[str] = None
        self._hierarchy_interm_nodes: Set[str] = set()
        self._hierarchy_leaves: Set[str] = set()

        # The mapping is stored by reference, not copied.
        self.level_structure: Dict[str, List[str]] = level_structure

        # Root/intermediate/leaf sets must be filled before the level traversal uses them.
        self._find_graph_structure(level_structure)
        hierarchy_levels = self._find_hierarchy_levels(level_structure)
        tree_depth = len(hierarchy_levels)

        if level_names is None:
            level_names = [f"level_{i}" for i in range(tree_depth)]

        if len(level_names) != tree_depth:
            raise ValueError("Length of `level_names` must be equal to hierarchy tree depth!")

        self.level_names: List[str] = level_names
        self._level_series: Dict[str, List[str]] = {
            name: hierarchy_levels[index] for index, name in enumerate(level_names)
        }
        self._level_to_index: Dict[str, int] = {name: index for index, name in enumerate(level_names)}

        # Number of direct children for every key of the adjacency list.
        self._sub_segment_size_map: Dict[str, int] = {k: len(v) for k, v in level_structure.items()}

        # Number of leaves below every node (a leaf counts itself).
        self._segment_subtree_size_map: Dict[str, int] = self._get_subtree_sizes(hierarchy_levels)

        self._segment_to_level: Dict[str, str] = {
            segment: level for level, segments in self._level_series.items() for segment in segments
        }

    def _find_graph_structure(self, adj_list: Dict[str, List[str]]) -> None:
        """Find the root of the tree and split the remaining nodes into intermediate nodes and leaves.

        Raises
        ------
        ValueError:
            If the number of nodes that are keys but never appear as a child is not exactly one.
        """
        children = set(chain.from_iterable(adj_list.values()))
        parents = set(adj_list)

        tree_roots = parents - children
        if len(tree_roots) != 1:
            raise ValueError("Invalid tree definition: unable to find root!")

        self._hierarchy_interm_nodes = parents & children
        self._hierarchy_leaves = children - parents
        self._hierarchy_root = tree_roots.pop()

    def _find_hierarchy_levels(self, hierarchy_structure: Dict[str, List[str]]) -> Dict[int, List[str]]:
        """Traverse hierarchy tree breadth-first to group segments into levels.

        Returns
        -------
        :
            mapping from level index (0 is the root level) to the nodes of that level in traversal order

        Raises
        ------
        ValueError:
            If the number of edges is not the number of nodes minus one, if leaves are found on
            different levels, or if not every node is reachable from the root.
        """
        nodes: Set[str] = self._hierarchy_interm_nodes | self._hierarchy_leaves
        nodes.add(str(self._hierarchy_root))
        num_nodes = len(nodes)

        num_edges = sum(map(len, hierarchy_structure.values()))
        if num_edges != num_nodes - 1:
            raise ValueError("Invalid tree definition: invalid number of nodes and edges!")

        leaves_level = None
        levels: Dict[int, List[str]] = defaultdict(list)
        seen_nodes = {self._hierarchy_root}
        queue: Queue = Queue()
        queue.put((self._hierarchy_root, 0))
        while not queue.empty():
            node, level = queue.get()
            levels[level].append(node)

            child_nodes = hierarchy_structure.get(node, [])
            if not child_nodes:
                # A node without children is a leaf; every leaf must share the first leaf's level.
                if leaves_level is not None and level != leaves_level:
                    raise ValueError("All hierarchy tree leaves must be on the same level!")
                leaves_level = level

            for adj_node in child_nodes:
                if adj_node not in seen_nodes:
                    queue.put((adj_node, level + 1))
                    seen_nodes.add(adj_node)

        if len(seen_nodes) != num_nodes:
            raise ValueError("Invalid tree definition: disconnected graph!")

        # Return a plain dict so that later lookups of a missing level fail instead of creating it.
        return dict(levels)

    def _get_subtree_sizes(self, hierarchy_levels: Dict[int, List[str]]) -> Dict[str, int]:
        """Compute for each node the number of leaves in its subtree (a leaf has size 1)."""
        subtree_size_map: Dict[str, int] = {}
        # Go from the deepest level upwards so that children are sized before their parents.
        for level in sorted(hierarchy_levels, reverse=True):
            for node in hierarchy_levels[level]:
                child_nodes = self.level_structure.get(node)
                if not child_nodes:
                    # Covers both a node that is absent from the adjacency list and a node listed
                    # with an empty child list; the level traversal treats both as leaves.
                    subtree_size_map[node] = 1
                    continue

                subtree_size_map[node] = sum(subtree_size_map[child_node] for child_node in child_nodes)

        return subtree_size_map

    def get_summing_matrix(self, target_level: str, source_level: str) -> scipy.sparse.spmatrix:
        """
        Get summing matrix for transition from source level to target level.

        Parameters
        ----------
        target_level:
            Name of target level.
        source_level:
            Name of source level.

        Returns
        -------
        :
            transition matrix from source level to target level in CSR format;
            rows correspond to target level segments, columns to source level segments

        Raises
        ------
        ValueError:
            If a level name is unknown or the target level is not above the source level.
        """
        try:
            target_idx = self._level_to_index[target_level]
            source_idx = self._level_to_index[source_level]
        except KeyError as e:
            raise ValueError(f"Invalid level name: {e.args[0]}") from e

        if target_idx >= source_idx:
            raise ValueError("Target level must be higher in hierarchy than source level!")

        target_level_list = self.get_level_segments(target_level)
        source_level_list = self.get_level_segments(source_level)
        # LIL format is used for element-wise assignment; the result is converted on return.
        summing_matrix = lil_matrix((len(target_level_list), len(source_level_list)))

        # Levels are stored in breadth-first order, so the source segments below consecutive
        # target segments form consecutive runs; a single moving column pointer is enough.
        source_position = 0
        for target_position, segment in enumerate(target_level_list):
            remaining_leaves = self._segment_subtree_size_map[segment]

            while remaining_leaves > 0:
                source_segment = source_level_list[source_position]
                remaining_leaves -= self._segment_subtree_size_map[source_segment]
                summing_matrix[target_position, source_position] = 1
                source_position += 1

        # `tocsr` returns a new matrix, so its result has to be the returned value.
        return summing_matrix.tocsr()

    def get_level_segments(self, level_name: str) -> List[str]:
        """Get all segments from particular level.

        Raises
        ------
        ValueError:
            If there is no level with the given name.
        """
        try:
            return self._level_series[level_name]
        except KeyError as e:
            raise ValueError(f"Invalid level name: {e.args[0]}") from e

    def get_segment_level(self, segment: str) -> Optional[str]:
        """Get level name for provided segment, or None if the segment is not in the hierarchy."""
        return self._segment_to_level.get(segment)
176 changes: 176 additions & 0 deletions tests/test_datasets/test_hierarchical_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from typing import Dict
from typing import List

import numpy as np
import pytest
import scipy.sparse

from etna.datasets import HierarchicalStructure


@pytest.fixture
def simple_hierarchical_struct():
    """Three-level hierarchy: total -> (X, Y), X -> (a, b), Y -> (c, d)."""
    structure = {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}
    names = ["l1", "l2", "l3"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.fixture
def tailed_hierarchical_struct():
    """Four-level hierarchy whose branches have different numbers of children."""
    structure = {
        "total": ["X", "Y"],
        "X": ["a"],
        "Y": ["c", "d"],
        "c": ["f"],
        "d": ["g"],
        "a": ["e", "h"],
    }
    names = ["l1", "l2", "l3", "l4"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.fixture
def long_hierarchical_struct():
    """Four-level hierarchy made of two chains hanging from the root."""
    structure = {"total": ["X", "Y"], "X": ["a"], "Y": ["b"], "a": ["c"], "b": ["d"]}
    names = ["l1", "l2", "l3", "l4"]
    return HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.mark.parametrize(
    "target,source",
    (
        ("l1", "l2"),
        ("l2", "l3"),
    ),
)
def test_get_summing_matrix_output(simple_hierarchical_struct: HierarchicalStructure, source: str, target: str):
    """Check that the summing matrix is returned as a scipy sparse matrix."""
    output = simple_hierarchical_struct.get_summing_matrix(target_level=target, source_level=source)
    # Use the public `scipy.sparse.spmatrix` name; `scipy.sparse.base` is a deprecated private namespace.
    assert isinstance(output, scipy.sparse.spmatrix)


@pytest.mark.parametrize(
    "struct, target,source,answer",
    (
        ("tailed_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("tailed_hierarchical_struct", "l1", "l3", np.array([[1, 1, 1]])),
        ("tailed_hierarchical_struct", "l1", "l4", np.array([[1, 1, 1, 1]])),
        ("tailed_hierarchical_struct", "l2", "l3", np.array([[1, 0, 0], [0, 1, 1]])),
        ("tailed_hierarchical_struct", "l2", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])),
        ("tailed_hierarchical_struct", "l3", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
        ("simple_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("simple_hierarchical_struct", "l1", "l3", np.array([[1, 1, 1, 1]])),
        ("simple_hierarchical_struct", "l2", "l3", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])),
        ("long_hierarchical_struct", "l1", "l2", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l1", "l3", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l1", "l4", np.array([[1, 1]])),
        ("long_hierarchical_struct", "l2", "l3", np.array([[1, 0], [0, 1]])),
        ("long_hierarchical_struct", "l2", "l4", np.array([[1, 0], [0, 1]])),
        ("long_hierarchical_struct", "l3", "l4", np.array([[1, 0], [0, 1]])),
    ),
)
def test_summing_matrix(struct: str, source: str, target: str, answer: np.ndarray, request: pytest.FixtureRequest):
    """Check the content of the summing matrix for every pair of levels of each fixture."""
    # The fixture is looked up by name because it is selected through parametrization.
    hierarchy = request.getfixturevalue(struct)
    summing_matrix = hierarchy.get_summing_matrix(target_level=target, source_level=source)
    np.testing.assert_array_almost_equal(answer, summing_matrix.toarray())


@pytest.mark.parametrize(
    "target,source,error",
    (
        ("l0", "l2", "Invalid level name: l0"),
        ("l1", "l0", "Invalid level name: l0"),
        ("l2", "l1", "Target level must be higher in hierarchy than source level!"),
    ),
)
def test_level_transition_errors(
    simple_hierarchical_struct: HierarchicalStructure,
    target: str,
    source: str,
    error: str,
):
    """Check that unknown level names and a wrong level order are rejected with ValueError."""
    levels = {"target_level": target, "source_level": source}
    with pytest.raises(ValueError, match=error):
        simple_hierarchical_struct.get_summing_matrix(**levels)


@pytest.mark.parametrize(
    "structure,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, "total"),
        ({"X": ["a", "b"]}, "X"),
    ),
)
def test_root_finding(structure: Dict[str, List[str]], answer: str):
    """Check that the root of a valid tree is detected."""
    hierarchy = HierarchicalStructure(level_structure=structure)
    found_root = hierarchy._hierarchy_root
    assert found_root == answer


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "total"]},  # loop to root
        {"X": ["a", "b"], "Y": ["c", "d"]},  # 2 trees
        {},  # empty adjacency list
    ),
)
def test_root_finding_errors(structure: Dict[str, List[str]]):
    """Check that definitions without exactly one root are rejected."""
    expected_message = "Invalid tree definition: unable to find root!"
    with pytest.raises(ValueError, match=expected_message):
        HierarchicalStructure(level_structure=structure)


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["X"]},  # loop
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "Y"]},  # self loop
    ),
)
def test_invalid_tree_structure_initialization_fails(structure: Dict[str, List[str]]):
    """Check that definitions with loops fail the nodes/edges count validation."""
    expected_message = "Invalid tree definition: invalid number of nodes and edges!"
    with pytest.raises(ValueError, match=expected_message):
        HierarchicalStructure(level_structure=structure)


@pytest.mark.parametrize(
    "structure,names,answer",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, None, ["level_0", "level_1", "level_2"]),
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3"], ["l1", "l2", "l3"]),
    ),
)
def test_level_names(structure: Dict[str, List[str]], names: List[str], answer: List[str]):
    """Check that level names are generated when omitted and kept as is when passed."""
    hierarchy = HierarchicalStructure(level_structure=structure, level_names=names)
    resulting_names = hierarchy.level_names
    assert resulting_names == answer


@pytest.mark.parametrize(
    "structure,names",
    (
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1"]),
        ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3", "l4"]),
    ),
)
def test_level_names_length_error(structure: Dict[str, List[str]], names: List[str]):
    """Check that too few or too many level names are rejected."""
    expected_message = "Length of `level_names` must be equal to hierarchy tree depth!"
    with pytest.raises(ValueError, match=expected_message):
        HierarchicalStructure(level_structure=structure, level_names=names)


@pytest.mark.parametrize(
    "level,answer",
    (("l1", ["total"]), ("l2", ["X", "Y"]), ("l3", ["a", "b", "c", "d"])),
)
def test_level_segments(simple_hierarchical_struct: HierarchicalStructure, level: str, answer: List[str]):
    """Check the segments reported for each level of the simple hierarchy."""
    segments = simple_hierarchical_struct.get_level_segments(level)
    assert segments == answer


@pytest.mark.parametrize(
    "segment,answer",
    (("total", "l1"), ("Y", "l2"), ("c", "l3")),
)
def test_segments_level(simple_hierarchical_struct: HierarchicalStructure, segment: str, answer: str):
    """Check the level reported for a segment from each level of the simple hierarchy."""
    level = simple_hierarchical_struct.get_segment_level(segment)
    assert level == answer


@pytest.mark.parametrize(
    "structure",
    (
        {"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"], "c": ["e", "f"]},  # e f leaves have lower level
        {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["e"]},  # e has lower level
    ),
)
def test_leaves_level_errors(structure: Dict[str, List[str]]):
    """Check that trees with leaves on different levels are rejected."""
    expected_message = "All hierarchy tree leaves must be on the same level!"
    with pytest.raises(ValueError, match=expected_message):
        HierarchicalStructure(level_structure=structure)