diff --git a/CHANGELOG.md b/CHANGELOG.md index f852ea6d5a..64be93e2f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `TreeNode.parent` -- a read-only property for accessing a node's parent https://github.com/Textualize/textual/issues/1397 - Added public `TreeNode` label access via `TreeNode.label` https://github.com/Textualize/textual/issues/1396 +- Added read-only public access to the children of a `TreeNode` via `TreeNode.children` https://github.com/Textualize/textual/issues/1398 ### Changed diff --git a/src/textual/_immutable_sequence_view.py b/src/textual/_immutable_sequence_view.py new file mode 100644 index 0000000000..27283dbc4e --- /dev/null +++ b/src/textual/_immutable_sequence_view.py @@ -0,0 +1,68 @@ +"""Provides an immutable sequence view class.""" + +from __future__ import annotations +from sys import maxsize +from typing import Generic, TypeVar, Iterator, overload, Sequence + +T = TypeVar("T") + + +class ImmutableSequenceView(Generic[T]): + """Class to wrap a sequence of some sort, but not allow modification.""" + + def __init__(self, wrap: Sequence[T]) -> None: + """Initialise the immutable sequence. + + Args: + wrap (Sequence[T]): The sequence being wrapped. + """ + self._wrap = wrap + + @overload + def __getitem__(self, index: int) -> T: + ... + + @overload + def __getitem__(self, index: slice) -> ImmutableSequenceView[T]: + ... + + def __getitem__(self, index: int | slice) -> T | ImmutableSequenceView[T]: + return ( + self._wrap[index] + if isinstance(index, int) + else ImmutableSequenceView[T](self._wrap[index]) + ) + + def __iter__(self) -> Iterator[T]: + return iter(self._wrap) + + def __len__(self) -> int: + return len(self._wrap) + + def __length_hint__(self) -> int: + return len(self) + + def __bool__(self) -> bool: + return bool(self._wrap) + + def __contains__(self, item: T) -> bool: + return item in self._wrap + + def index(self, item: T, start: int = 0, stop: int = maxsize) -> int: + """Return the index of the given item. + + Args: + item (T): The item to find in the sequence. + start (int, optional): Optional start location. + stop (int, optional): Optional stop location. + + Returns: + T: The index of the item in the sequence. + + Raises: + ValueError: If the item is not in the sequence. + """ + return self._wrap.index(item, start, stop) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._wrap) diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index 39b63a0376..866a4d7ceb 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -14,6 +14,7 @@ from .._segment_tools import line_crop, line_pad from .._types import MessageTarget from .._typing import TypeAlias +from .._immutable_sequence_view import ImmutableSequenceView from ..binding import Binding from ..geometry import Region, Size, clamp from ..message import Message @@ -53,6 +54,10 @@ def _get_guide_width(self, guide_depth: int, show_root: bool) -> int: return guides +class TreeNodes(ImmutableSequenceView["TreeNode[TreeDataType]"]): + """An immutable collection of `TreeNode`.""" + + @rich.repr.auto class TreeNode(Generic[TreeDataType]): """An object that represents a "node" in a tree control.""" @@ -91,6 +96,11 @@ def _reset(self) -> None: self._selected_ = False self._updates += 1 + @property + def children(self) -> TreeNodes[TreeDataType]: + """TreeNodes[TreeDataType]: The child nodes of a TreeNode.""" + return TreeNodes(self._children) + @property def line(self) -> int: """int: Get the line number for this node, or -1 if it is not displayed.""" diff --git a/tests/test_immutable_sequence_view.py b/tests/test_immutable_sequence_view.py new file mode 100644 index 0000000000..5af7f5133a --- /dev/null +++ b/tests/test_immutable_sequence_view.py @@ -0,0 +1,68 @@ +import pytest + +from typing import Sequence +from textual._immutable_sequence_view import ImmutableSequenceView + +def wrap(source: Sequence[int]) -> ImmutableSequenceView[int]: + """Wrap a sequence of integers inside an immutable sequence view.""" + return ImmutableSequenceView[int](source) + + +def test_empty_immutable_sequence() -> None: + """An empty immutable sequence should act as anticipated.""" + assert len(wrap([])) == 0 + assert bool(wrap([])) is False + assert list(wrap([])) == [] + + +def test_non_empty_immutable_sequence() -> None: + """A non-empty immutable sequence should act as anticipated.""" + assert len(wrap([0])) == 1 + assert bool(wrap([0])) is True + assert list(wrap([0])) == [0] + + +def test_no_assign_to_immutable_sequence() -> None: + """It should not be possible to assign into an immutable sequence.""" + tester = wrap([1,2,3,4,5]) + with pytest.raises(TypeError): + tester[0] = 23 + with pytest.raises(TypeError): + tester[0:3] = 23 + + +def test_no_del_from_iummutable_sequence() -> None: + """It should not be possible delete an item from an immutable sequence.""" + tester = wrap([1,2,3,4,5]) + with pytest.raises(TypeError): + del tester[0] + + +def test_get_item_from_immutable_sequence() -> None: + """It should be possible to get an item from an immutable sequence.""" + assert wrap(range(10))[0] == 0 + assert wrap(range(10))[-1] == 9 + + +def test_get_slice_from_immutable_sequence() -> None: + """It should be possible to get a slice from an immutable sequence.""" + assert list(wrap(range(10))[0:2]) == [0,1] + assert list(wrap(range(10))[0:-1]) == [0,1,2,3,4,5,6,7,8] + + +def test_immutable_sequence_contains() -> None: + """It should be possible to see if an immutable sequence contains a value.""" + tester = wrap([1,2,3,4,5]) + assert 1 in tester + assert 11 not in tester + + +def test_immutable_sequence_index() -> None: + tester = wrap([1,2,3,4,5]) + assert tester.index(1) == 0 + with pytest.raises(ValueError): + _ = tester.index(11) + + +def test_reverse_immutable_sequence() -> None: + assert list(reversed(wrap([1,2]))) == [2,1] diff --git a/tests/tree/test_tree_node_children.py b/tests/tree/test_tree_node_children.py new file mode 100644 index 0000000000..5b4751745a --- /dev/null +++ b/tests/tree/test_tree_node_children.py @@ -0,0 +1,27 @@ +import pytest +from textual.widgets import Tree, TreeNode + +def label_of(node: TreeNode[None]): + """Get the label of a node as a string""" + return str(node.label) + + +def test_tree_node_children() -> None: + """A node's children property should act like an immutable list.""" + CHILDREN=23 + tree = Tree[None]("Root") + for child in range(CHILDREN): + tree.root.add(str(child)) + assert len(tree.root.children)==CHILDREN + for child in range(CHILDREN): + assert label_of(tree.root.children[child]) == str(child) + assert label_of(tree.root.children[0]) == "0" + assert label_of(tree.root.children[-1]) == str(CHILDREN-1) + assert [label_of(node) for node in tree.root.children] == [str(n) for n in range(CHILDREN)] + assert [label_of(node) for node in tree.root.children[:2]] == [str(n) for n in range(2)] + with pytest.raises(TypeError): + tree.root.children[0] = tree.root.children[1] + with pytest.raises(TypeError): + del tree.root.children[0] + with pytest.raises(TypeError): + del tree.root.children[0:2]