Skip to content

Commit

Permalink
Update DataTree repr to indicate inheritance (pydata#9470)
Browse files Browse the repository at this point in the history
* Update DataTree repr to indicate inheritance

Fixes pydata#9463

* fix whitespace

* add more repr tests, fix failure

* fix failure on windows

* fix repr for inherited dimensions
  • Loading branch information
shoyer authored Sep 10, 2024
1 parent 9735216 commit a6bacfe
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 49 deletions.
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Datatree:
def setup(self):
run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
self.d_few = {"run1": run1}
self.d_many = {f"run{i}": run1.copy() for i in range(100)}
self.d_many = {f"run{i}": xr.Dataset({"a": 1}) for i in range(100)}

def time_from_dict_few(self):
DataTree.from_dict(self.d_few)
Expand Down
140 changes: 109 additions & 31 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import contextlib
import functools
import math
from collections import defaultdict
from collections.abc import Collection, Hashable, Sequence
from collections import ChainMap, defaultdict
from collections.abc import Collection, Hashable, Mapping, Sequence
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
from textwrap import dedent
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
Expand All @@ -29,6 +29,7 @@
if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree
from xarray.core.variable import Variable

UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

Expand Down Expand Up @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width):

def summarize_variable(
name: Hashable,
var,
var: Variable,
col_width: int,
max_width: int | None = None,
is_index: bool = False,
Expand Down Expand Up @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
)


def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None):
coords = _inherited_vars(node._coord_variables)
if col_width is None:
col_width = _calculate_col_width(coords)
return _mapping_repr(
coords,
title="Inherited coordinates",
summarizer=summarize_variable,
expand_option_name="display_expand_coords",
col_width=col_width,
indexes=node._indexes,
max_rows=max_rows,
)


def inline_index_repr(index: pd.Index, max_width=None):
if hasattr(index, "_repr_inline_"):
repr_ = index._repr_inline_(max_width=max_width)
Expand Down Expand Up @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool):
}


def indexes_repr(indexes, max_rows: int | None = None) -> str:
def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str:
col_width = _calculate_col_width(chain.from_iterable(indexes))

return _mapping_repr(
indexes,
"Indexes",
title,
summarize_index,
"display_expand_indexes",
col_width=col_width,
Expand Down Expand Up @@ -571,8 +587,10 @@ def _element_formatter(
return "".join(out)


def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str:
elements = [f"{k}: {v}" for k, v in obj.sizes.items()]
def dim_summary_limited(
sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None
) -> str:
elements = [f"{k}: {v}" for k, v in sizes.items()]
return _element_formatter(elements, col_width, max_rows)


Expand Down Expand Up @@ -676,7 +694,7 @@ def array_repr(arr):
data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"])

start = f"<xarray.{type(arr).__name__} {name_str}"
dims = dim_summary_limited(arr, col_width=len(start) + 1, max_rows=max_rows)
dims = dim_summary_limited(arr.sizes, col_width=len(start) + 1, max_rows=max_rows)
nbytes_str = render_human_readable_nbytes(arr.nbytes)
summary = [
f"{start}({dims})> Size: {nbytes_str}",
Expand Down Expand Up @@ -721,7 +739,9 @@ def dataset_repr(ds):
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -756,7 +776,9 @@ def dims_and_coords_repr(ds) -> str:
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
dims_values = dim_summary_limited(
ds.sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
Expand Down Expand Up @@ -1050,39 +1072,95 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
return "\n".join(summary)


def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
# TODO: change this to inherited=False, in order to clarify what is
# inherited? https://github.com/pydata/xarray/issues/9463
node_view = node._to_dataset_view(rebuild_dims=False, inherited=True)
ds_info = "\n" + repr(node_view)
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
def _inherited_vars(mapping: ChainMap) -> dict:
return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]}


def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
summary = [f"Group: {node.path}"]

col_width = _calculate_col_width(node.variables)
max_rows = OPTIONS["display_max_rows"]

inherited_coords = _inherited_vars(node._coord_variables)

# Only show dimensions if also showing a variable or coordinates section.
show_dims = (
node._node_coord_variables
or (show_inherited and inherited_coords)
or node._data_variables
)

dim_sizes = node.sizes if show_inherited else node._node_dims

if show_dims:
# Includes inherited dimensions.
dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(
dim_sizes, col_width=col_width + 1, max_rows=max_rows
)
summary.append(f"{dims_start}({dims_values})")

if node._node_coord_variables:
summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows))

if show_inherited and inherited_coords:
summary.append(
inherited_coords_repr(node, col_width=col_width, max_rows=max_rows)
)

if show_dims:
unindexed_dims_str = unindexed_dims_repr(
dim_sizes, node.coords, max_rows=max_rows
)
if unindexed_dims_str:
summary.append(unindexed_dims_str)

if node._data_variables:
summary.append(
data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows)
)

# TODO: only show indexes defined at this node, with a separate section for
# inherited indexes (if show_inherited=True)
display_default_indexes = _get_boolean_with_default(
"display_default_indexes", False
)
xindexes = filter_nondefault_indexes(
_get_indexes_dict(node.xindexes), not display_default_indexes
)
if xindexes:
summary.append(indexes_repr(xindexes, max_rows=max_rows))

def datatree_repr(dt: DataTree):
if node.attrs:
summary.append(attrs_repr(node.attrs, max_rows=max_rows))

return "\n".join(summary)


def datatree_repr(dt: DataTree) -> str:
"""A printable representation of the structure of this entire tree."""
renderer = RenderDataTree(dt)

name_info = "" if dt.name is None else f" {dt.name!r}"
header = f"<xarray.DataTree{name_info}>"

lines = [header]
show_inherited = True
for pre, fill, node in renderer:
node_repr = _single_node_repr(node)
node_repr = _datatree_node_repr(node, show_inherited=show_inherited)
show_inherited = False # only show inherited coords on the root

raw_repr_lines = node_repr.splitlines()

node_line = f"{pre}{node_repr.splitlines()[0]}"
node_line = f"{pre}{raw_repr_lines[0]}"
lines.append(node_line)

if node.has_data or node.has_attrs:
ds_repr = node_repr.splitlines()[2:]
for line in ds_repr:
if len(node.children) > 0:
lines.append(f"{fill}{renderer.style.vertical}{line}")
else:
lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")
for line in raw_repr_lines[1:]:
if len(node.children) > 0:
lines.append(f"{fill}{renderer.style.vertical}{line}")
else:
lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")

return "\n".join(lines)

Expand Down
Loading

0 comments on commit a6bacfe

Please sign in to comment.