From ffe12a5f103b9f06d728429fc0d930b76523726f Mon Sep 17 00:00:00 2001 From: Peyton Murray Date: Tue, 30 Aug 2022 15:09:40 -0700 Subject: [PATCH] [Tune] Add rich output for ray tune progress updates in notebooks (#26263) These changes are part of a series intended to improve integration with notebooks. This PR modifies the tune progress status shown to the user if tuning is run from a notebook. Previously, part of the trial progress was reported in an HTML table; now, all progress is displayed in an organized HTML template. Signed-off-by: pdmurray --- python/ray/tune/progress_reporter.py | 350 ++++++++++++++++-- .../widgets/templates/trial_progress.html.j2 | 17 + .../ray/widgets/templates/tune_status.html.j2 | 49 +++ .../templates/tune_status_messages.html.j2 | 25 ++ 4 files changed, 400 insertions(+), 41 deletions(-) create mode 100644 python/ray/widgets/templates/trial_progress.html.j2 create mode 100644 python/ray/widgets/templates/tune_status.html.j2 create mode 100644 python/ray/widgets/templates/tune_status_messages.html.j2 diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index b039f849c6ca..757eae44bdcc 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -3,11 +3,12 @@ import collections import datetime import numbers + import os import sys import time import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np from ray._private.dict import flatten_dict @@ -36,6 +37,8 @@ from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.queue import Queue +from ray.widgets import Template + try: from collections.abc import Mapping, MutableMapping except ImportError: @@ -544,16 +547,14 @@ def __init__( self.display("") # initialize empty display to update later def report(self, trials: List[Trial], done: bool, *sys_info: Dict): - progress_str = 
self._progress_str( - trials, done, *sys_info, fmt="html", delim="
" - ) + progress = self._progress_html(trials, done, *sys_info) if self.output_queue is not None: # If an output queue is set, send string - self.output_queue.put(progress_str) + self.output_queue.put(progress) else: # Else, output directly - self.display(progress_str) + self.display(progress) def display(self, string: str) -> None: from IPython.display import HTML, clear_output, display @@ -565,6 +566,71 @@ def display(self, string: str) -> None: else: self._display_handle.update(HTML(string)) + def _progress_html(self, trials: List[Trial], done: bool, *sys_info) -> str: + """Generate an HTML-formatted progress update. + + Args: + trials: List of trials for which progress should be + displayed + done: True if the trials are finished, False otherwise + *sys_info: System information to be displayed + + Returns: + Progress update to be rendered in a notebook, including HTML + tables and formatted error messages. Includes + - Duration of the tune job + - Memory consumption + - Trial progress table, with information about each experiment + """ + if not self._metrics_override: + user_metrics = self._infer_user_metrics(trials, self._infer_limit) + self._metric_columns.update(user_metrics) + + current_time, running_for = _get_time_str(self._start_time, time.time()) + used_gb, total_gb, memory_message = _get_memory_usage() + + status_table = tabulate( + [ + ("Current time:", current_time), + ("Running for:", running_for), + ("Memory:", f"{used_gb}/{total_gb} GiB"), + ], + tablefmt="html", + ) + trial_progress_data = _trial_progress_table( + trials=trials, + metric_columns=self._metric_columns, + parameter_columns=self._parameter_columns, + fmt="html", + max_rows=None if done else self._max_progress_rows, + metric=self._metric, + mode=self._mode, + sort_by_metric=self._sort_by_metric, + max_column_length=self._max_column_length, + ) + + trial_progress = trial_progress_data[0] + trial_progress_messages = trial_progress_data[1:] + trial_errors = _trial_errors_str( + trials, 
fmt="html", max_rows=None if done else self._max_error_rows + ) + + if any([memory_message, trial_progress_messages, trial_errors]): + msg = Template("tune_status_messages.html.j2").render( + memory_message=memory_message, + trial_progress_messages=trial_progress_messages, + trial_errors=trial_errors, + ) + else: + msg = None + + return Template("tune_status.html.j2").render( + status_table=status_table, + sys_info_message=_generate_sys_info_str(*sys_info), + trial_progress=trial_progress, + messages=msg, + ) + @PublicAPI class CLIReporter(TuneReporterBase): @@ -641,7 +707,14 @@ def report(self, trials: List[Trial], done: bool, *sys_info: Dict): print(self._progress_str(trials, done, *sys_info)) -def _memory_debug_str(): +def _get_memory_usage() -> Tuple[float, float, Optional[str]]: + """Get the current memory consumption. + + Returns: + Memory used, memory available, and optionally a warning + message to be shown to the user when memory consumption is higher + than 90% or if `psutil` is not installed + """ try: import ray # noqa F401 @@ -650,7 +723,7 @@ def _memory_debug_str(): total_gb = psutil.virtual_memory().total / (1024 ** 3) used_gb = total_gb - psutil.virtual_memory().available / (1024 ** 3) if used_gb > total_gb * 0.9: - warn = ( + message = ( ": ***LOW MEMORY*** less than 10% of the memory on " "this node is available for use. This can cause " "unexpected crashes. Consider " @@ -659,15 +732,41 @@ def _memory_debug_str(): "`object_store_memory` when calling `ray.init`." ) else: - warn = "" - return "Memory usage on this node: {}/{} GiB{}".format( - round(used_gb, 1), round(total_gb, 1), warn - ) + message = None + + return round(used_gb, 1), round(total_gb, 1), message except ImportError: - return "Unknown memory usage. Please run `pip install psutil` to resolve)" + return ( + np.nan, + np.nan, + "Unknown memory usage. 
Please run `pip install psutil` to resolve", + ) -def _time_passed_str(start_time: float, current_time: float): +def _memory_debug_str() -> str: + """Generate a message to be shown to the user showing memory consumption. + + Returns: + String to be shown to the user with formatted memory consumption + stats. + """ + used_gb, total_gb, message = _get_memory_usage() + if np.isnan(used_gb): + return message + else: + return f"Memory usage on this node: {used_gb}/{total_gb} GiB{message}" + + +def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]: + """Get strings representing the current and elapsed time. + + Args: + start_time: POSIX timestamp of the start of the tune run + current_time: POSIX timestamp giving the current time + + Returns: + Current time and elapsed time for the current run + """ current_time_dt = datetime.datetime.fromtimestamp(current_time) start_time_dt = datetime.datetime.fromtimestamp(start_time) delta: datetime.timedelta = current_time_dt - start_time_dt @@ -690,10 +789,22 @@ def _time_passed_str(start_time: float, current_time: float): running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}" - return ( - f"Current time: {current_time_dt:%Y-%m-%d %H:%M:%S} " - f"(running for {running_for_str})" - ) + return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str + + +def _time_passed_str(start_time: float, current_time: float) -> str: + """Generate a message describing the current and elapsed time in the run. 
+ + Args: + start_time: POSIX timestamp of the start of the tune run + current_time: POSIX timestamp giving the current time + + Returns: + Message with the current and elapsed time for the current tune run, + formatted to be displayed to the user + """ + current_time_str, running_for_str = _get_time_str(start_time, current_time) + return f"Current time: {current_time_str} " f"(running for {running_for_str})" def _get_trials_by_state(trials: List[Trial]): @@ -819,18 +930,38 @@ def _max_len(value: Any, max_len: int = 20, add_addr: bool = False) -> Any: return result -def _trial_progress_table( +def _get_progress_table_data( trials: List[Trial], metric_columns: Union[List[str], Dict[str, str]], parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, - fmt: str = "psql", max_rows: Optional[int] = None, metric: Optional[str] = None, mode: Optional[str] = None, sort_by_metric: bool = False, max_column_length: int = 20, -): - messages = [] +) -> Tuple[List, List[str], Tuple[bool, str]]: + """Generate a table showing the current progress of tuning trials. + + Args: + trials: List of trials for which progress is to be shown. + metric_columns: Metrics to be displayed in the table. + parameter_columns: List of parameters to be included in the data + max_rows: Maximum number of rows to show. 
If there's overflow, a + message will be shown to the user indicating that some rows + are not displayed + metric: Metric which is being tuned + mode: Sort the table in descending order if mode is "max"; + ascending otherwise + sort_by_metric: If true, the table will be sorted by the metric + max_column_length: Max number of characters in each column + + Returns: + - Trial data + - List of column names + - Overflow tuple: + - boolean indicating whether the table has rows which are hidden + - string with info about the overflowing rows + """ num_trials = len(trials) trials_by_state = _get_trials_by_state(trials) @@ -928,17 +1059,68 @@ def _trial_progress_table( + formatted_parameter_columns + formatted_metric_columns ) - # Tabulate. - messages.append( - tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False) + + return trial_table, columns, (overflow, overflow_str) + + +def _trial_progress_table( + trials: List[Trial], + metric_columns: Union[List[str], Dict[str, str]], + parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, + fmt: str = "psql", + max_rows: Optional[int] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sort_by_metric: bool = False, + max_column_length: int = 20, +) -> List[str]: + """Generate a list of trial progress table messages. + + Args: + trials: List of trials for which progress is to be shown. + metric_columns: Metrics to be displayed in the table. + parameter_columns: List of parameters to be included in the data + fmt: Format of the table; passed to tabulate as the tablefmt argument + max_rows: Maximum number of rows to show. 
If there's overflow, a + message will be shown to the user indicating that some rows + are not displayed + metric: Metric which is being tuned + mode: Sort the table in descending order if mode is "max"; + ascending otherwise + sort_by_metric: If true, the table will be sorted by the metric + max_column_length: Max number of characters in each column + + Returns: + Messages to be shown to the user containing progress tables + """ + data, columns, (overflow, overflow_str) = _get_progress_table_data( + trials, + metric_columns, + parameter_columns, + max_rows, + metric, + mode, + sort_by_metric, + max_column_length, ) + messages = [tabulate(data, headers=columns, tablefmt=fmt, showindex=False)] if overflow: - messages.append( - "... {} more trials not shown ({})".format(overflow, overflow_str) - ) + messages.append(f"... {overflow} more trials not shown ({overflow_str})") return messages +def _generate_sys_info_str(*sys_info) -> str: + """Format system info into a string. + *sys_info: System info strings to be included. + + Returns: + Formatted string containing system information. + """ + if sys_info: + return "
".join(sys_info).replace("\n", "
") + return "" + + def _trial_errors_str( trials: List[Trial], fmt: str = "psql", max_rows: Optional[int] = None ): @@ -1113,6 +1295,8 @@ def __init__( # Only use progress metrics if at least two metrics are in there if self._metric and self._progress_metrics: self._progress_metrics.add(self._metric) + self._last_result = {} + self._display_handle = None def on_trial_result( self, @@ -1147,30 +1331,45 @@ def on_trial_complete( def log_result(self, trial: "Trial", result: Dict, error: bool = False): done = result.get("done", False) is True last_print = self._last_print[trial] + should_print = done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL + if done and trial not in self._completed_trials: self._completed_trials.add(trial) - if has_verbosity(Verbosity.V3_TRIAL_DETAILS) and ( - done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL - ): + + if should_print: + if IS_NOTEBOOK: + self.display_result(trial, result, error, done) + else: + self.print_result(trial, result, error, done) + + self._last_print[trial] = time.time() + + def print_result(self, trial: Trial, result: Dict, error: bool, done: bool): + """Print the most recent results for the given trial to stdout. + + Args: + trial: Trial for which results are to be printed + result: Result to be printed + error: True if an error has occurred, False otherwise + done: True if the trial is finished, False otherwise + """ + if has_verbosity(Verbosity.V3_TRIAL_DETAILS): print("Result for {}:".format(trial)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) - self._last_print[trial] = time.time() - elif has_verbosity(Verbosity.V2_TRIAL_NORM) and ( - done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL - ): - info = "" - if done: - info = " This trial completed." 
+ elif has_verbosity(Verbosity.V2_TRIAL_NORM): metric_name = self._metric or "_metric" metric_value = result.get(metric_name, -99.0) + error_file = os.path.join(trial.logdir, "error.txt") + + info = "" + if done: + info = " This trial completed." print_result_str = self._print_result(result) self._last_result_str[trial] = print_result_str - error_file = os.path.join(trial.logdir, "error.txt") - if error: message = ( f"The trial {trial} errored with " @@ -1191,7 +1390,76 @@ def log_result(self, trial: "Trial", result: Dict, error: bool = False): ) print(message) - self._last_print[trial] = time.time() + + def generate_trial_table( + self, trials: Dict[Trial, Dict], columns: List[str] + ) -> str: + """Generate an HTML table of trial progress info. + + Trials (rows) are sorted by name; progress stats (columns) are sorted + as well. + + Args: + trials: Trials and their associated latest results + columns: Columns to show in the table; must be a list of valid + keys for each Trial result + + Returns: + HTML template containing a rendered table of progress info + """ + data = [] + columns = sorted(columns) + + sorted_trials = collections.OrderedDict( + sorted(self._last_result.items(), key=lambda item: str(item[0])) + ) + for trial, result in sorted_trials.items(): + data.append([str(trial)] + [result.get(col, "") for col in columns]) + + return Template("trial_progress.html.j2").render( + table=tabulate( + data, tablefmt="html", headers=["Trial name"] + columns, showindex=False + ) + ) + + def display_result(self, trial: Trial, result: Dict, error: bool, done: bool): + """Display a formatted HTML table of trial progress results. + + Trial progress is only shown if verbosity is set to level 2 or 3. 
+ + Args: + trial: Trial for which results are to be printed + result: Result to be printed + error: True if an error has occurred, False otherwise + done: True if the trial is finished, False otherwise + """ + from IPython.display import display, HTML + + self._last_result[trial] = result + if has_verbosity(Verbosity.V3_TRIAL_DETAILS): + ignored_keys = { + "config", + "hist_stats", + } + + elif has_verbosity(Verbosity.V2_TRIAL_NORM): + ignored_keys = { + "config", + "hist_stats", + "trial_id", + "experiment_tag", + "done", + } | set(AUTO_RESULT_KEYS) + else: + return + + table = self.generate_trial_table( + self._last_result, set(result.keys()) - ignored_keys + ) + if not self._display_handle: + self._display_handle = display(HTML(table), display_id=True) + else: + self._display_handle.update(HTML(table)) def _print_result(self, result: Dict): if self._progress_metrics: diff --git a/python/ray/widgets/templates/trial_progress.html.j2 b/python/ray/widgets/templates/trial_progress.html.j2 new file mode 100644 index 000000000000..f3a323193e7f --- /dev/null +++ b/python/ray/widgets/templates/trial_progress.html.j2 @@ -0,0 +1,17 @@ +
+

Trial Progress

+ {{ table }} +
+ diff --git a/python/ray/widgets/templates/tune_status.html.j2 b/python/ray/widgets/templates/tune_status.html.j2 new file mode 100644 index 000000000000..df422f89af4e --- /dev/null +++ b/python/ray/widgets/templates/tune_status.html.j2 @@ -0,0 +1,49 @@ +
+
+
+

Tune Status

+ {{ status_table }} +
+
+
+

System Info

+ {{ sys_info_message }} +
+ {{ messages }} +
+
+
+

Trial Status

+ {{ trial_progress }} +
+
+ diff --git a/python/ray/widgets/templates/tune_status_messages.html.j2 b/python/ray/widgets/templates/tune_status_messages.html.j2 new file mode 100644 index 000000000000..da8e75f5f58d --- /dev/null +++ b/python/ray/widgets/templates/tune_status_messages.html.j2 @@ -0,0 +1,25 @@ +
+
+

Messages

+ {{ memory_message }} + {{ trial_progress_messages }} + {{ trial_errors }} +
+