[Datasets] Persist Datasets statistics to log file (ray-project#30557)
Currently, when we print Dataset stats after execution, there is no way to retrieve this information if the job fails or crashes. By persisting the stats to a separate log file, we can still access them for debugging. By default, this file is logs/ray-data.log under the Ray session directory.

The new logger, DatasetLogger, always writes logs to the ray-data.log file and optionally also writes to stdout (enabled by default). The motivation is to let users filter the dedicated log file for Dataset logs while still keeping console logs for those who rely on them.
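
As a rough illustration (not part of this change), locating and reading the persisted stats from a driver script could look like the sketch below, assuming a local cluster started with ray.init() and the default logs/ray-data.log path described above:

```
import os

import ray
from ray.data._internal.dataset_logger import DatasetLogger

ray.init()

# Run any Dataset workload; each stage's stats summary is appended to ray-data.log.
ds = ray.data.range(1000, parallelism=10).map_batches(lambda x: x)
for _ in ds.iter_batches():
    pass

# The log file lives under the current Ray session directory,
# e.g. /tmp/ray/session_<timestamp>/logs/ray-data.log.
session_dir = ray._private.worker._global_node.get_session_dir_path()
log_path = os.path.join(session_dir, DatasetLogger.DEFAULT_DATASET_LOG_PATH)
with open(log_path) as f:
    print(f.read())
```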

Signed-off-by: tmynn <[email protected]>
scottjlee authored and tamohannes committed Jan 25, 2023
1 parent d15342b commit 912530e
Showing 4 changed files with 195 additions and 62 deletions.
75 changes: 75 additions & 0 deletions python/ray/data/_internal/dataset_logger.py
@@ -0,0 +1,75 @@
import logging
import os

import ray
from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL


class DatasetLogger:
    """Logger for Ray Datasets which writes logs to a separate log file
    at `DatasetLogger.DEFAULT_DATASET_LOG_PATH`. Can optionally turn off
    logging to stdout to reduce clutter (but always logs to the aforementioned
    Datasets-specific log file).

    After initialization, always use the `get_logger()` method to correctly
    set whether to log to stdout. Example usage:

    ```
    logger = DatasetLogger(__name__)
    logger.get_logger().info("This logs to file and stdout")
    logger.get_logger(log_to_stdout=False).info("This logs to file only")
    logger.get_logger().warning("Can call the usual Logger methods")
    ```
    """

    DEFAULT_DATASET_LOG_PATH = "logs/ray-data.log"

    def __init__(self, log_name: str):
        """Initialize DatasetLogger for a given `log_name`.

        Args:
            log_name: Name of logger (usually passed into `logging.getLogger(...)`)
        """
        # Logger used to log to the log file (in addition to the root logger,
        # which logs to stdout as normal). For logging calls made with the
        # parameter `log_to_stdout = False`, `_logger.propagate` will be set
        # to `False` in order to prevent the root logger from writing the log
        # to stdout.
        self._logger = logging.getLogger(f"{log_name}.logfile")
        # We need to set the log level again when explicitly
        # initializing a new logger (otherwise it can have an undesirable level).
        self._logger.setLevel(LOGGER_LEVEL.upper())

        # Add a log handler which writes to a separate Datasets log file
        # at `DatasetLogger.DEFAULT_DATASET_LOG_PATH`.
        global_node = ray._private.worker._global_node
        if global_node is not None:
            # With the current implementation, we can only get the session_dir
            # after ray.init() is called. A less hacky way could potentially fix this.
            session_dir = global_node.get_session_dir_path()
            self.datasets_log_path = os.path.join(
                session_dir,
                DatasetLogger.DEFAULT_DATASET_LOG_PATH,
            )
            # Add a FileHandler to write to the specific Ray Datasets log file,
            # using the standard default logger format used by the root logger.
            file_log_handler = logging.FileHandler(self.datasets_log_path)
            file_log_formatter = logging.Formatter(fmt=LOGGER_FORMAT)
            file_log_handler.setFormatter(file_log_formatter)
            self._logger.addHandler(file_log_handler)

    def get_logger(self, log_to_stdout: bool = True):
        """Returns the underlying Logger, with the `propagate` attribute set
        to the same value as `log_to_stdout`. For example, when
        `log_to_stdout = False`, we do not want the `DatasetLogger` to
        propagate up to the base Logger, which writes to stdout.

        This is a workaround needed because the DatasetLogger wrapper object
        does not have access to the log caller's scope in Python <3.8.
        In the future, with Python 3.8 support, we can use the `stacklevel` arg,
        which allows the logger to fetch the correct calling file/line and
        also removes the need for this getter method:
        `logger.info(msg="Hello world", stacklevel=2)`
        """
        self._logger.propagate = log_to_stdout
        return self._logger
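
The `propagate` toggle above is plain standard-library behavior. As a rough standalone sketch (hypothetical module and file names, no Ray involved), a child logger always writes to its own file handler, while its `propagate` flag decides whether the root logger also echoes the record to the console:

```
import logging
import sys

# Root logger writes to the console (stand-in for Ray's usual console logging).
root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(logging.StreamHandler(sys.stdout))

# Child logger with its own file handler (stand-in for the "<name>.logfile" logger).
child = logging.getLogger("my_module.logfile")
child.setLevel(logging.INFO)
child.addHandler(logging.FileHandler("/tmp/my-module.log"))

child.propagate = True
child.info("written to /tmp/my-module.log and echoed to the console")

child.propagate = False
child.info("written to /tmp/my-module.log only")
```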
11 changes: 7 additions & 4 deletions python/ray/data/_internal/plan.py
@@ -1,7 +1,6 @@
 import copy
 import functools
 import itertools
-import logging
 import uuid
 from typing import (
     TYPE_CHECKING,
@@ -27,6 +26,7 @@
     get_compute,
     is_task_compute,
 )
+from ray.data._internal.dataset_logger import DatasetLogger
 from ray.data._internal.lazy_block_list import LazyBlockList
 from ray.data._internal.stats import DatasetStats
 from ray.data.block import Block
@@ -40,7 +40,7 @@
 INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


-logger = logging.getLogger(__name__)
+logger = DatasetLogger(__name__)


 class Stage:
@@ -326,8 +326,11 @@ def execute(
                 else:
                     stats = stats_builder.build(blocks)
                 stats.dataset_uuid = uuid.uuid4().hex
-                if context.enable_auto_log_stats:
-                    logger.info(stats.summary_string(include_parent=False))
+
+                stats_summary_string = stats.summary_string(include_parent=False)
+                logger.get_logger(log_to_stdout=context.enable_auto_log_stats).info(
+                    stats_summary_string,
+                )
             # Set the snapshot to the output of the final stage.
             self._snapshot_blocks = blocks
             self._snapshot_stats = stats
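
A hedged end-to-end sketch (not part of the diff) of what the new `execute()` path implies for callers, assuming `enable_auto_log_stats` is the `DatasetContext` flag referenced above: turning it off suppresses the stdout echo, but the summary is still written to ray-data.log because `get_logger()` only toggles propagation.

```
import ray
from ray.data.context import DatasetContext

ray.init()

ctx = DatasetContext.get_current()
ctx.enable_auto_log_stats = False  # no stdout echo; file logging is unaffected

ds = ray.data.range(1000, parallelism=10).map_batches(lambda x: x)
for _ in ds.iter_batches():  # triggers execution, so the stats summary is emitted
    pass
# The per-stage summary is appended to <session_dir>/logs/ray-data.log
# via the DatasetLogger file handler, even though nothing extra hits stdout.
```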
50 changes: 50 additions & 0 deletions python/ray/data/tests/test_dataset_logger.py
@@ -0,0 +1,50 @@
import pytest
from ray.tests.conftest import *  # noqa

import os
import re
import logging

from datetime import datetime

import ray
from ray.data._internal.dataset_logger import DatasetLogger


def test_dataset_logger(shutdown_only):
    ray.init()
    log_name, msg = "test_name", "test_message_1234"
    logger = DatasetLogger(log_name)
    logger.get_logger().info(msg)

    # Read from log file, and parse each component of emitted log row
    session_dir = ray._private.worker._global_node.get_session_dir_path()
    log_file_path = os.path.join(session_dir, DatasetLogger.DEFAULT_DATASET_LOG_PATH)
    with open(log_file_path, "r") as f:
        raw_logged_msg = f.read()
    (
        logged_ds,
        logged_ts,
        logged_level,
        logged_filepath,
        sep,
        logged_msg,
    ) = raw_logged_msg.split()

    # Could not use freezegun to test exact timestamp value
    # (values off by some milliseconds), so instead we check
    # for correct timestamp format.
    try:
        datetime.strptime(f"{logged_ds} {logged_ts}", "%Y-%m-%d %H:%M:%S,%f")
    except ValueError:
        raise Exception(f"Invalid log timestamp: {logged_ds} {logged_ts}")

    assert logged_level == logging.getLevelName(logging.INFO)
    assert re.match(r"test_dataset_logger.py:\d+", logged_filepath)
    assert logged_msg == msg


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
121 changes: 63 additions & 58 deletions python/ray/data/tests/test_stats.py
@@ -3,6 +3,7 @@
 import pytest

 import ray
+from ray.data._internal.dataset_logger import DatasetLogger
 from ray.data.context import DatasetContext
 from ray.tests.conftest import *  # noqa

@@ -21,43 +22,45 @@ def canonicalize(stats: str) -> str:
     return s4


-@patch("ray.data._internal.plan.logger")
-def test_dataset_stats_basic(
-    mock_logger, ray_start_regular_shared, enable_auto_log_stats
-):
+def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
     context = DatasetContext.get_current()
     context.optimize_fuse_stages = True
-    ds = ray.data.range(1000, parallelism=10)
-    ds = ds.map_batches(lambda x: x)
-
-    if enable_auto_log_stats:
-        logger_args, logger_kwargs = mock_logger.info.call_args
-        assert (
-            canonicalize(logger_args[0])
-            == """Stage N read->map_batches: N/N blocks executed in T
+
+    logger = DatasetLogger("ray.data._internal.plan").get_logger(
+        log_to_stdout=enable_auto_log_stats,
+    )
+    with patch.object(logger, "info") as mock_logger:
+        ds = ray.data.range(1000, parallelism=10)
+        ds = ds.map_batches(lambda x: x)
+
+        if enable_auto_log_stats:
+            logger_args, logger_kwargs = mock_logger.call_args
+            assert (
+                canonicalize(logger_args[0])
+                == """Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
 * Output num rows: N min, N max, N mean, N total
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used
 """
-        )
-
-    ds = ds.map(lambda x: x)
-    if enable_auto_log_stats:
-        logger_args, logger_kwargs = mock_logger.info.call_args
-        assert (
-            canonicalize(logger_args[0])
-            == """Stage N map: N/N blocks executed in T
+            )
+
+        ds = ds.map(lambda x: x)
+        if enable_auto_log_stats:
+            logger_args, logger_kwargs = mock_logger.call_args
+            assert (
+                canonicalize(logger_args[0])
+                == """Stage N map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
 * Output num rows: N min, N max, N mean, N total
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used
 """
-        )
+            )
     for batch in ds.iter_batches():
         pass
     stats = canonicalize(ds.stats())
@@ -227,70 +230,72 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
     )


-@patch("ray.data._internal.plan.logger")
-def test_dataset_pipeline_stats_basic(
-    mock_logger, ray_start_regular_shared, enable_auto_log_stats
-):
+def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
     context = DatasetContext.get_current()
     context.optimize_fuse_stages = True
-    ds = ray.data.range(1000, parallelism=10)
-    ds = ds.map_batches(lambda x: x)
-
-    if enable_auto_log_stats:
-        logger_args, logger_kwargs = mock_logger.info.call_args
-        assert (
-            canonicalize(logger_args[0])
-            == """Stage N read->map_batches: N/N blocks executed in T
+
+    logger = DatasetLogger("ray.data._internal.plan").get_logger(
+        log_to_stdout=enable_auto_log_stats,
+    )
+    with patch.object(logger, "info") as mock_logger:
+        ds = ray.data.range(1000, parallelism=10)
+        ds = ds.map_batches(lambda x: x)
+
+        if enable_auto_log_stats:
+            logger_args, logger_kwargs = mock_logger.call_args
+            assert (
+                canonicalize(logger_args[0])
+                == """Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
 * Output num rows: N min, N max, N mean, N total
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used
 """
-        )
-
-    pipe = ds.repeat(5)
-    pipe = pipe.map(lambda x: x)
-    if enable_auto_log_stats:
-        # Stats only include first stage, and not for pipelined map
-        logger_args, logger_kwargs = mock_logger.info.call_args
-        assert (
-            canonicalize(logger_args[0])
-            == """Stage N read->map_batches: N/N blocks executed in T
+            )
+
+        pipe = ds.repeat(5)
+        pipe = pipe.map(lambda x: x)
+        if enable_auto_log_stats:
+            # Stats only include first stage, and not for pipelined map
+            logger_args, logger_kwargs = mock_logger.call_args
+            assert (
+                canonicalize(logger_args[0])
+                == """Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
 * Output num rows: N min, N max, N mean, N total
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used
 """
-        )
+            )

-    stats = canonicalize(pipe.stats())
-    assert "No stats available" in stats, stats
-    for batch in pipe.iter_batches():
-        pass
+        stats = canonicalize(pipe.stats())
+        assert "No stats available" in stats, stats
+        for batch in pipe.iter_batches():
+            pass

-    if enable_auto_log_stats:
-        # Now stats include the pipelined map stage
-        logger_args, logger_kwargs = mock_logger.info.call_args
-        assert (
-            canonicalize(logger_args[0])
-            == """Stage N map: N/N blocks executed in T
+        if enable_auto_log_stats:
+            # Now stats include the pipelined map stage
+            logger_args, logger_kwargs = mock_logger.call_args
+            assert (
+                canonicalize(logger_args[0])
+                == """Stage N map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Peak heap memory usage (MiB): N min, N max, N mean
 * Output num rows: N min, N max, N mean, N total
 * Output size bytes: N min, N max, N mean, N total
 * Tasks per node: N min, N max, N mean; N nodes used
 """
-        )
+            )

-    stats = canonicalize(pipe.stats())
-    assert (
-        stats
-        == """== Pipeline Window N ==
+        stats = canonicalize(pipe.stats())
+        assert (
+            stats
+            == """== Pipeline Window N ==
 Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
@@ -341,7 +346,7 @@ def test_dataset_pipeline_stats_basic(
 * In user code: T
 * Total time: T
 """
-    )
+        )


 def test_dataset_pipeline_cache_cases(ray_start_regular_shared):
