Skip to content

Commit

Permalink
Add RichProgressBar
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Feb 16, 2024
1 parent d429d75 commit fa776cf
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
112 changes: 112 additions & 0 deletions cubed/extensions/rich.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import logging
import sys
from contextlib import contextmanager

from rich.console import RenderableType
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
Task,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
)
from rich.text import Text

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class RichProgressBar(Callback):
"""Rich progress bar for a computation."""

def on_compute_start(self, event):
# Set the pulse_style to the background colour to disable pulsing,
# since Rich will pulse all non-started bars.
logger_aware_progress = LoggerAwareProgress(
SpinnerWhenRunningColumn(),
TextColumn("[progress.description]{task.description}"),
LeftJustifiedMofNCompleteColumn(),
BarColumn(bar_width=None, pulse_style="bar.back"),
TaskProgressColumn(
text_format="[progress.percentage]{task.percentage:>3.1f}%"
),
TimeElapsedColumn(),
logger=logging.getLogger(),
)
progress = logger_aware_progress.__enter__()

progress_tasks = {}
for name, node in visit_nodes(event.dag, event.resume):
num_tasks = node["primitive_op"].num_tasks
progress_task = progress.add_task(f"{name}", start=False, total=num_tasks)
progress_tasks[name] = progress_task

self.logger_aware_progress = logger_aware_progress
self.progress = progress
self.progress_tasks = progress_tasks

def on_compute_end(self, event):
self.logger_aware_progress.__exit__(None, None, None)

def on_operation_start(self, event):
self.progress.start_task(self.progress_tasks[event.name])

def on_task_end(self, event):
self.progress.update(self.progress_tasks[event.name], advance=event.num_tasks)


class SpinnerWhenRunningColumn(SpinnerColumn):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Override so spinner is not shown when bar has not yet started
def render(self, task: "Task") -> RenderableType:
text = (
self.finished_text
if not task.started or task.finished
else self.spinner.render(task.get_time())
)
return text


class LeftJustifiedMofNCompleteColumn(MofNCompleteColumn):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def render(self, task: "Task") -> Text:
"""Show completed/total."""
completed = int(task.completed)
total = int(task.total) if task.total is not None else "?"
total_width = len(str(total))
return Text(
f"{completed}{self.separator}{total}".ljust(total_width + 1 + total_width),
style="progress.download",
)


# Based on CustomProgress from https://github.com/Textualize/rich/discussions/1578
@contextmanager
def LoggerAwareProgress(*args, **kwargs):
"""Wrapper around rich.progress.Progress to manage logging output to stderr."""
try:
__logger = kwargs.pop("logger", None)
streamhandlers = [
x for x in __logger.root.handlers if type(x) is logging.StreamHandler
]

with Progress(*args, **kwargs) as progress:
for handler in streamhandlers:
__prior_stderr = handler.stream
handler.setStream(sys.stderr)

yield progress

finally:
streamhandlers = [
x for x in __logger.root.handlers if type(x) is logging.StreamHandler
]
for handler in streamhandlers:
handler.setStream(__prior_stderr)
14 changes: 14 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import cubed.array_api as xp
import cubed.random
from cubed.extensions.history import HistoryCallback
from cubed.extensions.rich import RichProgressBar
from cubed.extensions.timeline import TimelineVisualizationCallback
from cubed.extensions.tqdm import TqdmProgressBar
from cubed.primitive.blockwise import apply_blockwise
Expand Down Expand Up @@ -97,6 +98,19 @@ def test_callbacks(spec, executor):
assert task_counter.value == num_created_arrays + 4


def test_rich_progress_bar(spec, executor):
# test indirectly by checking it doesn't cause a failure
progress = RichProgressBar()

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
assert_array_equal(
c.compute(executor=executor, callbacks=[progress]),
np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
)


@pytest.mark.cloud
def test_callbacks_modal(spec, modal_executor):
task_counter = TaskCounter(check_timestamps=False)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ diagnostics = [
"pydot",
"pandas",
"matplotlib",
"rich",
"seaborn",
]
beam = ["apache-beam", "gcsfs"]
Expand Down

0 comments on commit fa776cf

Please sign in to comment.