Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataTable checks #5059

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 62 additions & 20 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,13 +1013,14 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
Raises:
ColumnDoesNotExist: If there is no column corresponding to the key.
"""
if column_key not in self._column_locations:
raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")

data = self._data
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
yield data[row_key][column_key]
try:
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
row = data[row_key]
yield row[column_key]
except KeyError:
raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")

def get_column_at(self, column_index: int) -> Iterable[CellType]:
"""Get the values from the column at a given index.
Expand Down Expand Up @@ -1051,9 +1052,10 @@ def get_column_index(self, column_key: ColumnKey | str) -> int:
Raises:
ColumnDoesNotExist: If the column key does not exist.
"""
if column_key not in self._column_locations:
column_index = self._column_locations.get(column_key)
if column_index is None:
raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}")
return self._column_locations.get(column_key)
return column_index

def _clear_caches(self) -> None:
self._row_render_cache.clear()
Expand All @@ -1075,7 +1077,12 @@ def get_row_height(self, row_key: RowKey) -> int:
"""
if row_key is self._header_row_key:
return self.header_height
return self.rows[row_key].height

row = self.rows.get(row_key)
if row is None:
raise RowDoesNotExist(f"Row key {row_key!r} is not valid.")

return row.height

def notify_style_update(self) -> None:
self._row_render_cache.clear()
Expand Down Expand Up @@ -1466,14 +1473,19 @@ def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None:
self._total_row_height + header_height,
)

def _get_cell_region(self, coordinate: Coordinate) -> Region:
def _get_cell_region(self, coordinate: Coordinate) -> Region | None:
"""Get the region of the cell at the given spatial coordinate."""
if not self.is_valid_coordinate(coordinate):
return Region(0, 0, 0, 0)
return None

row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index)
row = self.rows[row_key]
if row_key is None:
return None

row = self.rows.get(row_key)
if row is None:
return None

# The x-coordinate of a cell is the sum of widths of the data cells to the left
# plus the width of the render width of the longest row label.
Expand All @@ -1485,6 +1497,8 @@ def _get_cell_region(self, coordinate: Coordinate) -> Region:
+ self._row_label_column_width
)
column_key = self._column_locations.get_key(column_index)
if column_key is None:
return None
width = self.columns[column_key].get_render_width(self)
height = row.height
y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index])
Expand All @@ -1493,28 +1507,36 @@ def _get_cell_region(self, coordinate: Coordinate) -> Region:
cell_region = Region(x, y, width, height)
return cell_region

def _get_row_region(self, row_index: int) -> Region:
def _get_row_region(self, row_index: int) -> Region | None:
"""Get the region of the row at the given index."""
if not self.is_valid_row_index(row_index):
return Region(0, 0, 0, 0)
return None

rows = self.rows
row_key = self._row_locations.get_key(row_index)
row = rows[row_key]
if row_key is None:
return None

row = rows.get(row_key)
if row is None:
return None

row_width = (
sum(column.get_render_width(self) for column in self.columns.values())
+ self._row_label_column_width
)
y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index])

if self.show_header:
y += self.header_height

row_region = Region(0, y, row_width, row.height)
return row_region

def _get_column_region(self, column_index: int) -> Region:
def _get_column_region(self, column_index: int) -> Region | None:
"""Get the region of the column at the given index."""
if not self.is_valid_column_index(column_index):
return Region(0, 0, 0, 0)
return None

columns = self.columns
x = (
Expand All @@ -1525,6 +1547,9 @@ def _get_column_region(self, column_index: int) -> Region:
+ self._row_label_column_width
)
column_key = self._column_locations.get_key(column_index)
if column_key is None:
return None

width = columns[column_key].get_render_width(self)
header_height = self.header_height if self.show_header else 0
height = self._total_row_height + header_height
Expand Down Expand Up @@ -1737,7 +1762,10 @@ def remove_row(self, row_key: RowKey | str) -> None:
self.check_idle()

index_to_delete = self._row_locations.get(row_key)
new_row_locations = TwoWayDict({})
if index_to_delete is None:
raise RowDoesNotExist(f"Row key {row_key!r} is not valid.")

new_row_locations = TwoWayDict[RowKey, int]({})
for row_location_key in self._row_locations:
row_index = self._row_locations.get(row_location_key)
if row_index > index_to_delete:
Expand Down Expand Up @@ -1833,6 +1861,8 @@ def refresh_coordinate(self, coordinate: Coordinate) -> Self:
if not self.is_valid_coordinate(coordinate):
return self
region = self._get_cell_region(coordinate)
if region is None:
return self
self._refresh_region(region)
return self

Expand All @@ -1849,6 +1879,8 @@ def refresh_row(self, row_index: int) -> Self:
return self

region = self._get_row_region(row_index)
if region is None:
return self
self._refresh_region(region)
return self

Expand All @@ -1865,6 +1897,8 @@ def refresh_column(self, column_index: int) -> Self:
return self

region = self._get_column_region(column_index)
if region is None:
return self
self._refresh_region(region)
return self

Expand Down Expand Up @@ -2535,13 +2569,21 @@ def _scroll_cursor_into_view(self, animate: bool = False) -> None:
top, _, _, left = fixed_offset

if self.cursor_type == "row":
x, y, width, height = self._get_row_region(self.cursor_row)
row_region = self._get_row_region(self.cursor_row)
if row_region is None:
return
x, y, width, height = row_region
region = Region(int(self.scroll_x) + left, y, width - left, height)
elif self.cursor_type == "column":
x, y, width, height = self._get_column_region(self.cursor_column)
column_region = self._get_column_region(self.cursor_column)
if column_region is None:
return
x, y, width, height = column_region
region = Region(x, int(self.scroll_y) + top, width, height - top)
else:
region = self._get_cell_region(self.cursor_coordinate)
if region is None:
return

self.scroll_to_region(region, animate=animate, spacing=fixed_offset, force=True)

Expand Down
Loading