Skip to content

Commit

Permalink
[air/output] Print experiment information at experiment start (#34952)
Browse files Browse the repository at this point in the history
Includes #34788

This PR prints experiment information at the start of the experiment.

Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
krfricke authored May 4, 2023
1 parent bedecac commit 581967f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
59 changes: 58 additions & 1 deletion python/ray/tune/experimental/output.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import List, Dict, Optional, Tuple, Any, TYPE_CHECKING, Collection

import contextlib
Expand Down Expand Up @@ -463,6 +464,26 @@ def __init__(self, verbosity: AirVerbosity):
self._start_time = time.time()
self._last_heartbeat_time = 0

def experiment_started(
self,
experiment_name: str,
experiment_path: str,
searcher_str: str,
scheduler_str: str,
total_num_samples: int,
tensorboard_path: Optional[str] = None,
**kwargs,
):
print(f"\nView detailed results here: {experiment_path}")

if tensorboard_path:
print(
f"To visualize your results with TensorBoard, run: "
f"`tensorboard --logdir {tensorboard_path}`"
)

print("")

@property
def _time_heartbeat_str(self):
current_time_str, running_for_str = _get_time_str(self._start_time, time.time())
Expand Down Expand Up @@ -557,6 +578,42 @@ def _print_heartbeat(self, trials, *sys_args):


class TuneTerminalReporter(TuneReporterBase):
def experiment_started(
self,
experiment_name: str,
experiment_path: str,
searcher_str: str,
scheduler_str: str,
total_num_samples: int,
tensorboard_path: Optional[str] = None,
**kwargs,
):
if total_num_samples > sys.maxsize:
total_num_samples_str = "infinite"
else:
total_num_samples_str = str(total_num_samples)

print(
tabulate(
[
["Search algorithm", searcher_str],
["Scheduler", scheduler_str],
["Number of trials", total_num_samples_str],
],
headers=["Configuration for experiment", experiment_name],
tablefmt=AIR_TABULATE_TABLEFMT,
)
)
super().experiment_started(
experiment_name=experiment_name,
experiment_path=experiment_path,
searcher_str=searcher_str,
scheduler_str=scheduler_str,
total_num_samples=total_num_samples,
tensorboard_path=tensorboard_path,
**kwargs,
)

def _print_heartbeat(self, trials, *sys_args):
if self._verbosity < self._heartbeat_threshold:
return
Expand All @@ -575,7 +632,7 @@ def _print_heartbeat(self, trials, *sys_args):
tabulate(
all_infos,
headers=header,
tablefmt="simple",
tablefmt=AIR_TABULATE_TABLEFMT,
showindex=False,
)
)
Expand Down
15 changes: 15 additions & 0 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)

from ray.tune.impl.placeholder import create_resolvers_map, inject_placeholders
from ray.tune.logger import TBXLoggerCallback
from ray.tune.progress_reporter import (
ProgressReporter,
_detect_reporter,
Expand Down Expand Up @@ -957,10 +958,24 @@ class and registered trainables.
with contextlib.ExitStack() as stack:
from ray.tune.experimental.output import TuneRichReporter

if any(isinstance(cb, TBXLoggerCallback) for cb in callbacks):
tensorboard_path = runner._local_experiment_path
else:
tensorboard_path = None

if air_progress_reporter and isinstance(
air_progress_reporter, TuneRichReporter
):
stack.enter_context(air_progress_reporter.with_live())
elif air_progress_reporter:
air_progress_reporter.experiment_started(
experiment_name=runner._experiment_dir_name,
experiment_path=runner.experiment_path,
searcher_str=search_alg.__class__.__name__,
scheduler_str=scheduler.__class__.__name__,
total_num_samples=search_alg.total_samples,
tensorboard_path=tensorboard_path,
)

try:
while (
Expand Down

0 comments on commit 581967f

Please sign in to comment.