[Datasets] Persist Datasets statistics to log file #30557

Merged
merged 25 commits on Dec 7, 2022
Changes from 20 commits
64 changes: 64 additions & 0 deletions python/ray/data/_internal/dataset_logger.py
@@ -0,0 +1,64 @@
import logging
import os

import ray
from ray._private.ray_constants import LOGGER_FORMAT, LOGGER_LEVEL


class DatasetLogger:
    """Logger for Ray Datasets which, in addition to logging to stdout,
    also writes to a separate log file at `DatasetLogger.DEFAULT_DATASET_LOG_PATH`.
    """

    DEFAULT_DATASET_LOG_PATH = "logs/ray-data.log"

    def __init__(self, log_name: str):
        """Initialize DatasetLogger for a given `log_name`.

        Args:
            log_name: Name of logger (usually passed into `logging.getLogger(...)`)
        """
        # Logger used for logging to the log file (in addition to the root logger,
        # which logs to stdout as normal). We set `logger.propagate` to False
        # to ensure the file logger only logs to the file, and not stdout, by default.
Contributor: This documentation seems confusing to me since the class-level comment says it writes to the file in addition to stdout.

Contributor Author: Reworded the comments and documentation to clarify here and in the class-level comment; let me know if things are still confusing. Thanks!

        # For logging calls made with the parameter `log_to_stdout = False`,
        # `logger.propagate` will be set to `False` in order to prevent the
        # root logger from writing the log to stdout.
        self.logger = logging.getLogger(f"{log_name}.logfile")
        # We need to set the log level again when explicitly
        # initializing a new logger (otherwise can have undesirable level).
        self.logger.setLevel(LOGGER_LEVEL.upper())

        # Add log handler which writes to a separate Datasets log file
        # at `DatasetLogger.DEFAULT_DATASET_LOG_PATH`
        global_node = ray._private.worker._global_node
        if global_node is not None:
            # With current implementation, we can only get session_dir
            # after ray.init() is called. A less hacky way could potentially fix this
            session_dir = global_node.get_session_dir_path()
            self.datasets_log_path = os.path.join(
                session_dir,
                DatasetLogger.DEFAULT_DATASET_LOG_PATH,
            )
            # Add a FileHandler to write to the specific Ray Datasets log file,
            # using the standard default logger format used by the root logger
            file_log_handler = logging.FileHandler(self.datasets_log_path)
            file_log_formatter = logging.Formatter(fmt=LOGGER_FORMAT)
            file_log_handler.setFormatter(file_log_formatter)
            self.logger.addHandler(file_log_handler)

    def get_logger(self, log_to_stdout: bool = True):
        """
        Returns the underlying Logger, with the `propagate` attribute set
        to the same value as `log_to_stdout`. For example, when
        `log_to_stdout = False`, we do not want the `DatasetLogger` to
        propagate up to the base Logger which writes to stdout.

        This is a workaround needed due to the DatasetLogger wrapper object
        not having access to the log caller's scope in Python <3.8.
        In the future, with Python 3.8 support, we can use the `stacklevel` arg,
        which allows the logger to fetch the correct calling file/line:
        `logger.info(msg="Hello world", stacklevel=2)`
        """
        self.logger.propagate = log_to_stdout
        return self.logger
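
A minimal usage sketch of the class above (not part of the diff; it assumes `ray.init()` has already been called so that the session directory, and therefore the log file path, exists):

import ray
from ray.data._internal.dataset_logger import DatasetLogger

ray.init()

# Writes to <session_dir>/logs/ray-data.log and, since propagation to the
# root logger stays enabled by default, also shows up on stdout.
logger = DatasetLogger(__name__)
logger.get_logger().info("visible on stdout and in ray-data.log")

# With log_to_stdout=False, propagation to the root logger is turned off,
# so this record should land only in the log file.
logger.get_logger(log_to_stdout=False).info("file-only message")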
11 changes: 7 additions & 4 deletions python/ray/data/_internal/plan.py
@@ -1,7 +1,6 @@
import copy
import functools
import itertools
import logging
import uuid
from typing import (
TYPE_CHECKING,
@@ -27,6 +26,7 @@
get_compute,
is_task_compute,
)
from ray.data._internal.dataset_logger import DatasetLogger
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.stats import DatasetStats
from ray.data.block import Block
@@ -40,7 +40,7 @@
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


logger = logging.getLogger(__name__)
logger = DatasetLogger(__name__)


class Stage:
@@ -319,8 +319,11 @@ def execute(
else:
stats = stats_builder.build(blocks)
stats.dataset_uuid = uuid.uuid4().hex
if context.enable_auto_log_stats:
logger.info(stats.summary_string(include_parent=False))

stats_summary_string = stats.summary_string(include_parent=False)
logger.get_logger(log_to_stdout=context.enable_auto_log_stats).info(
msg=stats_summary_string,
)
# Set the snapshot to the output of the final stage.
self._snapshot_blocks = blocks
self._snapshot_stats = stats
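
The net effect of the change above: the stats summary is now always written to the Datasets log file, while `context.enable_auto_log_stats` only controls whether it also propagates to stdout. A rough sketch of inspecting the persisted summary after running a dataset (the session-directory lookup mirrors the new test below; it is illustrative, not a public API):

import os

import ray
from ray.data._internal.dataset_logger import DatasetLogger

ray.init()
# Executing the dataset triggers plan execution, which logs the stats
# summary to the Datasets log file regardless of enable_auto_log_stats.
ds = ray.data.range(1000, parallelism=10).map_batches(lambda x: x)

session_dir = ray._private.worker._global_node.get_session_dir_path()
log_path = os.path.join(session_dir, DatasetLogger.DEFAULT_DATASET_LOG_PATH)
with open(log_path) as f:
    print(f.read())  # expect a "Stage ... read->map_batches: ... blocks executed in ..." summary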
50 changes: 50 additions & 0 deletions python/ray/data/tests/test_dataset_logger.py
@@ -0,0 +1,50 @@
import pytest
from ray.tests.conftest import *  # noqa

import os
import re
import logging

from datetime import datetime

import ray
from ray.data._internal.dataset_logger import DatasetLogger


def test_dataset_logger(shutdown_only):
    ray.init()
    log_name, msg = "test_name", "test_message_1234"
    logger = DatasetLogger(log_name)
    logger.get_logger().info(msg)

    # Read from log file, and parse each component of emitted log row
    session_dir = ray._private.worker._global_node.get_session_dir_path()
    log_file_path = os.path.join(session_dir, DatasetLogger.DEFAULT_DATASET_LOG_PATH)
    with open(log_file_path, "r") as f:
        raw_logged_msg = f.read()
        (
            logged_ds,
            logged_ts,
            logged_level,
            logged_filepath,
            sep,
            logged_msg,
        ) = raw_logged_msg.split()

    # Could not use freezegun to test exact timestamp value
    # (values off by some milliseconds), so instead we check
    # for correct timestamp format.
    try:
        datetime.strptime(f"{logged_ds} {logged_ts}", "%Y-%m-%d %H:%M:%S,%f")
    except ValueError:
        raise Exception(f"Invalid log timestamp: {logged_ds} {logged_ts}")

    assert logged_level == logging.getLevelName(logging.INFO)
    assert re.match(r"test_dataset_logger.py:\d+", logged_filepath)
    assert logged_msg == msg


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
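
Because the module-level `logger` in plan.py is now a `DatasetLogger` wrapper rather than a plain `logging.Logger`, the test_stats.py changes below drop the `@patch("ray.data._internal.plan.logger")` decorator and instead patch the `info` method of the underlying logger instance returned by `get_logger()`. A condensed sketch of that pattern (the assertion value is hypothetical):

from unittest.mock import patch

import ray
from ray.data._internal.dataset_logger import DatasetLogger
from ray.tests.conftest import *  # noqa


def test_stats_logged(ray_start_regular_shared):
    # DatasetLogger("ray.data._internal.plan") resolves to the same underlying
    # logging.Logger singleton that plan.py uses, so patching its `info`
    # method intercepts the stats summary emitted during execution.
    logger = DatasetLogger("ray.data._internal.plan").get_logger()
    with patch.object(logger, "info") as mock_info:
        ray.data.range(1000, parallelism=10).map_batches(lambda x: x)
        _, logger_kwargs = mock_info.call_args
        assert "read->map_batches" in logger_kwargs["msg"]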
121 changes: 63 additions & 58 deletions python/ray/data/tests/test_stats.py
@@ -3,6 +3,7 @@
import pytest

import ray
from ray.data._internal.dataset_logger import DatasetLogger
from ray.data.context import DatasetContext
from ray.tests.conftest import * # noqa

@@ -21,43 +22,45 @@ def canonicalize(stats: str) -> str:
return s4


@patch("ray.data._internal.plan.logger")
def test_dataset_stats_basic(
mock_logger, ray_start_regular_shared, enable_auto_log_stats
):
def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)

if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.info.call_args
assert (
canonicalize(logger_args[0])
== """Stage N read->map_batches: N/N blocks executed in T
logger = DatasetLogger("ray.data._internal.plan").get_logger(
log_to_stdout=enable_auto_log_stats,
)
with patch.object(logger, "info") as mock_logger:
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)

if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.call_args
assert (
canonicalize(logger_kwargs["msg"])
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
"""
)

ds = ds.map(lambda x: x)
if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.info.call_args
assert (
canonicalize(logger_args[0])
== """Stage N map: N/N blocks executed in T
)

ds = ds.map(lambda x: x)
if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.call_args
assert (
canonicalize(logger_kwargs["msg"])
== """Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
"""
)
)
for batch in ds.iter_batches():
pass
stats = canonicalize(ds.stats())
@@ -227,70 +230,72 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
)


@patch("ray.data._internal.plan.logger")
def test_dataset_pipeline_stats_basic(
mock_logger, ray_start_regular_shared, enable_auto_log_stats
):
def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)

if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.info.call_args
assert (
canonicalize(logger_args[0])
== """Stage N read->map_batches: N/N blocks executed in T
logger = DatasetLogger("ray.data._internal.plan").get_logger(
log_to_stdout=enable_auto_log_stats,
)
with patch.object(logger, "info") as mock_logger:
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)

if enable_auto_log_stats:
logger_args, logger_kwargs = mock_logger.call_args
assert (
canonicalize(logger_kwargs["msg"])
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
"""
)

pipe = ds.repeat(5)
pipe = pipe.map(lambda x: x)
if enable_auto_log_stats:
# Stats only include first stage, and not for pipelined map
logger_args, logger_kwargs = mock_logger.info.call_args
assert (
canonicalize(logger_args[0])
== """Stage N read->map_batches: N/N blocks executed in T
)

pipe = ds.repeat(5)
pipe = pipe.map(lambda x: x)
if enable_auto_log_stats:
# Stats only include first stage, and not for pipelined map
logger_args, logger_kwargs = mock_logger.call_args
assert (
canonicalize(logger_kwargs["msg"])
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
"""
)
)

stats = canonicalize(pipe.stats())
assert "No stats available" in stats, stats
for batch in pipe.iter_batches():
pass
stats = canonicalize(pipe.stats())
assert "No stats available" in stats, stats
for batch in pipe.iter_batches():
pass

if enable_auto_log_stats:
# Now stats include the pipelined map stage
logger_args, logger_kwargs = mock_logger.info.call_args
assert (
canonicalize(logger_args[0])
== """Stage N map: N/N blocks executed in T
if enable_auto_log_stats:
# Now stats include the pipelined map stage
logger_args, logger_kwargs = mock_logger.call_args
assert (
canonicalize(logger_kwargs["msg"])
== """Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
"""
)
)

stats = canonicalize(pipe.stats())
assert (
stats
== """== Pipeline Window N ==
stats = canonicalize(pipe.stats())
assert (
stats
== """== Pipeline Window N ==
Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
@@ -341,7 +346,7 @@ def test_dataset_pipeline_stats_basic(
* In user code: T
* Total time: T
"""
)
)


def test_dataset_pipeline_cache_cases(ray_start_regular_shared):