Skip to content

Commit

Permalink
[AIR] Add __repr__ to AIR classes (ray-project#27006)
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
bveeramani authored and Stefan van der Kleij committed Aug 18, 2022
1 parent cdcee32 commit b864104
Show file tree
Hide file tree
Showing 37 changed files with 471 additions and 75 deletions.
8 changes: 8 additions & 0 deletions python/ray/air/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ py_test(
deps = [":ml_lib"]
)

py_test(
name = "test_configs",
size = "small",
srcs = ["tests/test_configs.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)

py_test(
name = "test_data_batch_conversion",
size = "small",
Expand Down
4 changes: 4 additions & 0 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def __init__(
self._uri: Optional[str] = uri
self._obj_ref: Optional[ray.ObjectRef] = obj_ref

def __repr__(self):
    """Return ``ClassName(parameter=argument)`` for this checkpoint.

    The ``parameter``/``argument`` pair is whatever
    ``get_internal_representation()`` returns for this instance.
    """
    kind, payload = self.get_internal_representation()
    class_name = self.__class__.__name__
    return f"{class_name}({kind}={payload})"

@classmethod
def from_bytes(cls, data: bytes) -> "Checkpoint":
"""Create a checkpoint from the given byte string.
Expand Down
76 changes: 74 additions & 2 deletions python/ray/air/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Mapping, Optional, Union, Tuple
from dataclasses import _MISSING_TYPE, dataclass, fields
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Union,
Tuple,
)

from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -28,6 +38,44 @@
MIN = "min"


def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) -> str:
"""A utility function to elegantly represent dataclasses.
In contrast to the default dataclass `__repr__`, which shows all parameters, this
function only shows parameters with non-default values.
Args:
obj: The dataclass to represent.
default_values: An optional dictionary that maps field names to default values.
Use this parameter to specify default values that are generated dynamically
(e.g., in `__post_init__` or by a `default_factory`). If a default value
isn't specified in `default_values`, then the default value is inferred from
the `dataclass`.
Returns:
A representation of the dataclass.
"""
if default_values is None:
default_values = {}

non_default_values = {} # Maps field name to value.

for field in fields(obj):
value = getattr(obj, field.name)
default_value = default_values.get(field.name, field.default)
is_required = isinstance(field.default, _MISSING_TYPE)
if is_required or value != default_value:
non_default_values[field.name] = value

string = f"{obj.__class__.__name__}("
string += ", ".join(
f"{name}={value!r}" for name, value in non_default_values.items()
)
string += ")"

return string


@dataclass
@PublicAPI(stability="alpha")
class ScalingConfig:
Expand Down Expand Up @@ -87,6 +135,9 @@ def __post_init__(self):
"`resources_per_worker."
)

# Delegates to `_repr_dataclass`, which prints only fields with non-default values.
def __repr__(self):
return _repr_dataclass(self)

def __eq__(self, o: "ScalingConfig") -> bool:
if not isinstance(o, type(self)):
return False
Expand Down Expand Up @@ -272,6 +323,9 @@ class DatasetConfig:
# True by default.
randomize_block_order: Optional[bool] = None

# Delegates to `_repr_dataclass`, which prints only fields with non-default values.
def __repr__(self):
return _repr_dataclass(self)

def fill_defaults(self) -> "DatasetConfig":
"""Return a copy of this config with all default values filled in."""
return DatasetConfig(
Expand Down Expand Up @@ -404,6 +458,9 @@ def __post_init__(self):
"fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
)

# Delegates to `_repr_dataclass`, which prints only fields with non-default values.
def __repr__(self):
return _repr_dataclass(self)


@dataclass
@PublicAPI(stability="alpha")
Expand Down Expand Up @@ -468,6 +525,9 @@ def __post_init__(self):
f"checkpoint_frequency must be >=0, got {self.checkpoint_frequency}"
)

# Delegates to `_repr_dataclass`, which prints only fields with non-default values.
def __repr__(self):
return _repr_dataclass(self)

@property
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
Expand Down Expand Up @@ -550,3 +610,15 @@ def __post_init__(self):

if not self.checkpoint_config:
self.checkpoint_config = CheckpointConfig()

def __repr__(self):
    """Represent this config, omitting sub-configs equal to a freshly built default.

    Freshly constructed `FailureConfig`, `SyncConfig` and `CheckpointConfig`
    instances are passed to `_repr_dataclass` as the defaults for the
    corresponding fields.
    """
    # Imported at call time, as in the original; presumably this avoids an import
    # cycle with `ray.tune` -- TODO confirm.
    from ray.tune.syncer import SyncConfig

    dynamic_defaults = {
        "failure_config": FailureConfig(),
        "sync_config": SyncConfig(),
        "checkpoint_config": CheckpointConfig(),
    }
    return _repr_dataclass(self, default_values=dynamic_defaults)
4 changes: 4 additions & 0 deletions python/ray/air/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@

# Name to use for the column when representing tensors in table format.
TENSOR_COLUMN_NAME = "__value__"

# The maximum length of strings returned by `__repr__` for AIR objects constructed with
# default values.
MAX_REPR_LENGTH = int(80 * 1.5)
13 changes: 12 additions & 1 deletion python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import re
import shutil
import tempfile
import unittest
Expand All @@ -8,7 +9,7 @@
import ray
from ray.air._internal.remote_storage import delete_at_uri, _ensure_directory
from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
from ray.air.constants import PREPROCESSOR_KEY
from ray.air.constants import MAX_REPR_LENGTH, PREPROCESSOR_KEY
from ray.data import Preprocessor


Expand All @@ -20,6 +21,16 @@ def transform_batch(self, df):
return df * self.multiplier


def test_repr():
    """`repr` of a dict-backed checkpoint is short and shaped like ``Checkpoint(...)``."""
    text = repr(Checkpoint(data_dict={"foo": "bar"}))

    assert len(text) < MAX_REPR_LENGTH
    assert re.match("^Checkpoint\\((.*)\\)$", text)


class CheckpointsConversionTest(unittest.TestCase):
def setUp(self):
self.tmpdir = os.path.realpath(tempfile.mkdtemp())
Expand Down
41 changes: 41 additions & 0 deletions python/ray/air/tests/test_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from ray.air.config import (
ScalingConfig,
DatasetConfig,
FailureConfig,
CheckpointConfig,
RunConfig,
)
from ray.air.constants import MAX_REPR_LENGTH


# Each config class is exercised with no arguments and with at least one explicit
# argument; `RunConfig` is additionally given an explicitly passed `FailureConfig()`.
@pytest.mark.parametrize(
"config",
[
ScalingConfig(),
ScalingConfig(use_gpu=True),
DatasetConfig(),
DatasetConfig(fit=True),
FailureConfig(),
FailureConfig(max_failures=2),
CheckpointConfig(),
CheckpointConfig(num_to_keep=1),
RunConfig(),
RunConfig(name="experiment"),
RunConfig(failure_config=FailureConfig()),
],
)
def test_repr(config):
representation = repr(config)

# The repr must evaluate back to an equal config ...
assert eval(representation) == config
# ... and stay strictly below the shared length budget.
assert len(representation) < MAX_REPR_LENGTH


# When executed as a script, run this file's tests with pytest (verbose, stop at the
# first failure) and exit with pytest's return code.
if __name__ == "__main__":
import sys

import pytest

sys.exit(pytest.main(["-v", "-x", __file__]))
2 changes: 1 addition & 1 deletion python/ray/data/preprocessors/batch_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":

def __repr__(self):
fn_name = getattr(self.fn, "__name__", self.fn)
return f"BatchMapper(fn={fn_name})"
return f"{self.__class__.__name__}(fn={fn_name})"
3 changes: 2 additions & 1 deletion python/ray/data/preprocessors/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
return df

def __repr__(self):
return f"Chain(preprocessors={self.preprocessors})"
arguments = ", ".join(repr(preprocessor) for preprocessor in self.preprocessors)
return f"{self.__class__.__name__}({arguments})"
43 changes: 26 additions & 17 deletions python/ray/data/preprocessors/concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,24 @@ def __init__(
raise_if_missing: bool = False,
):
self.output_column_name = output_column_name
self.included_columns = include
self.excluded_columns = exclude or []
self.include = include
self.exclude = exclude or []
self.dtype = dtype
self.raise_if_missing = raise_if_missing

def _validate(self, df: pd.DataFrame):
total_columns = set(df)
if self.excluded_columns and self.raise_if_missing:
missing_columns = set(self.excluded_columns) - total_columns.intersection(
set(self.excluded_columns)
if self.exclude and self.raise_if_missing:
missing_columns = set(self.exclude) - total_columns.intersection(
set(self.exclude)
)
if missing_columns:
raise ValueError(
f"Missing columns specified in 'exclude': {missing_columns}"
)
if self.included_columns and self.raise_if_missing:
missing_columns = set(self.included_columns) - total_columns.intersection(
set(self.included_columns)
if self.include and self.raise_if_missing:
missing_columns = set(self.include) - total_columns.intersection(
set(self.include)
)
if missing_columns:
raise ValueError(
Expand All @@ -84,10 +84,10 @@ def _transform_pandas(self, df: pd.DataFrame):
self._validate(df)

included_columns = set(df)
if self.included_columns: # subset of included columns
included_columns = set(self.included_columns)
if self.include: # subset of included columns
included_columns = set(self.include)

columns_to_concat = list(included_columns - set(self.excluded_columns))
columns_to_concat = list(included_columns - set(self.exclude))
concatenated = df[columns_to_concat].to_numpy(dtype=self.dtype)
df = df.drop(columns=columns_to_concat)
try:
Expand All @@ -98,9 +98,18 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
return (
f"Concatenator(output_column_name={self.output_column_name}, "
f"include={self.included_columns}, "
f"exclude={self.excluded_columns}, "
f"dtype={self.dtype})"
)
default_values = {
"output_column_name": "concat_out",
"include": None,
"exclude": [],
"dtype": None,
"raise_if_missing": False,
}

non_default_arguments = []
for parameter, default_value in default_values.items():
value = getattr(self, parameter)
if value != default_value:
non_default_arguments.append(f"{parameter}={value}")

return f"{self.__class__.__name__}({', '.join(non_default_arguments)})"
24 changes: 13 additions & 11 deletions python/ray/data/preprocessors/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ def list_as_category(element):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"OrdinalEncoder(columns={self.columns}, stats={stats}, "
f"encode_lists={self.encode_lists})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"encode_lists={self.encode_lists!r})"
)


Expand Down Expand Up @@ -191,8 +190,10 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"OneHotEncoder(columns={self.columns}, stats={stats})"
return (
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"max_categories={self.max_categories!r})"
)


class MultiHotEncoder(Preprocessor):
Expand Down Expand Up @@ -277,8 +278,10 @@ def encode_list(element: list, *, name: str):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"MultiHotEncoder(columns={self.columns}, stats={stats})"
return (
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"max_categories={self.max_categories!r})"
)


class LabelEncoder(Preprocessor):
Expand Down Expand Up @@ -313,8 +316,7 @@ def column_label_encoder(s: pd.Series):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"LabelEncoder(label_column={self.label_column}, stats={stats})"
return f"{self.__class__.__name__}(label_column={self.label_column!r})"


class Categorizer(Preprocessor):
Expand Down Expand Up @@ -367,9 +369,9 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"<Categorizer columns={self.columns} dtypes={self.dtypes} stats={stats}>"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"dtypes={self.dtypes!r})"
)


Expand Down
3 changes: 2 additions & 1 deletion python/ray/data/preprocessors/hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ def row_feature_hasher(row):

def __repr__(self):
return (
f"FeatureHasher(columns={self.columns}, num_features={self.num_features})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"num_features={self.num_features!r})"
)
8 changes: 2 additions & 6 deletions python/ray/data/preprocessors/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,9 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"SimpleImputer("
f"columns={self.columns}, "
f"strategy={self.strategy}, "
f"fill_value={self.fill_value}, "
f"stats={stats})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"strategy={self.strategy!r}, fill_value={self.fill_value!r})"
)


Expand Down
Loading

0 comments on commit b864104

Please sign in to comment.