Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Add __repr__ to AIR classes #27006

Merged
merged 35 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b51aa74
Add `__repr__` for checkpoints
bveeramani Jul 26, 2022
ca2e9a7
Add config `__repr__`s
bveeramani Jul 26, 2022
50b1b04
Add preprocessor `__repr__`s
bveeramani Jul 26, 2022
bb36088
Add trainer `__repr__`
bveeramani Jul 26, 2022
40468f6
Update `Preprocessor`s
bveeramani Jul 26, 2022
ad08a51
Fix stuff
bveeramani Jul 26, 2022
448a67c
Add predictor `__repr__`s
bveeramani Jul 26, 2022
829cf73
Appease lint
bveeramani Jul 26, 2022
6e881c3
Update tensorflow_predictor.py
bveeramani Jul 26, 2022
7c90cf2
Merge branch 'master' into bveeramani/air-repr2
bveeramani Jul 26, 2022
32737da
Fix docstring example
bveeramani Jul 26, 2022
a7abc5a
Update preprocessor reprs
bveeramani Jul 26, 2022
1632231
More preprocessor fix
bveeramani Jul 26, 2022
c509445
Update base_trainer.py
bveeramani Jul 26, 2022
9c62b92
Update config.py
bveeramani Jul 26, 2022
83b3e89
Update predictors
bveeramani Jul 26, 2022
3cc7836
Remove example
bveeramani Jul 26, 2022
819259d
Merge branch 'master' into bveeramani/air-repr2
bveeramani Jul 26, 2022
a327638
Resolve merge issues
bveeramani Jul 26, 2022
6ee74e3
Add preprocessor repr tests
bveeramani Jul 26, 2022
13efed3
Add checkpoint tests
bveeramani Jul 26, 2022
c7bdad0
Appease lint
bveeramani Jul 26, 2022
a54fdb6
Add some tests
bveeramani Jul 26, 2022
e3924b9
Add tests
bveeramani Jul 26, 2022
e0682ff
Remove whitespace
bveeramani Jul 26, 2022
74e297d
Add tests to `BUILD`
bveeramani Jul 26, 2022
125e116
Update python/ray/train/tests/test_batch_predictor.py
richardliaw Jul 27, 2022
85b6966
Fix dataset tests
bveeramani Jul 27, 2022
9f33bc7
Fix `DatasetPipeline`
bveeramani Jul 27, 2022
a304fc4
Format files
bveeramani Jul 27, 2022
28d71f4
Merge branch 'master' into bveeramani/air-repr2
bveeramani Jul 27, 2022
1cfb3b6
Revert dataset changes
bveeramani Jul 27, 2022
f6ed176
Update test_dataset_config.py
bveeramani Jul 27, 2022
4091a42
Merge remote-tracking branch 'upstream/master' into bveeramani/air-repr2
bveeramani Jul 27, 2022
c05dc4b
Merge master into branch
richardliaw Jul 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def __init__(
self._uri: Optional[str] = uri
self._obj_ref: Optional[ray.ObjectRef] = obj_ref

def __repr__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

parameter, argument = self.get_internal_representation()
return f"{self.__class__.__name__}({parameter}={argument})"

@classmethod
def from_bytes(cls, data: bytes) -> "Checkpoint":
"""Create a checkpoint from the given byte string.
Expand Down
76 changes: 74 additions & 2 deletions python/ray/air/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Mapping, Optional, Union, Tuple
from dataclasses import _MISSING_TYPE, dataclass, fields
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Union,
Tuple,
)

from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -27,6 +37,44 @@
MIN = "min"


def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) -> str:
    """A utility function to elegantly represent dataclasses.

    In contrast to the default dataclass `__repr__`, which shows all parameters, this
    function only shows parameters with non-default values.

    Args:
        obj: The dataclass to represent.
        default_values: An optional dictionary that maps field names to default values.
            Use this parameter to specify default values that are generated dynamically
            (e.g., in `__post_init__` or by a `default_factory`). If a default value
            isn't specified in `default_values`, then the default value is inferred from
            the `dataclass`: the field's `default` if it has one, otherwise the result
            of calling its `default_factory`.

    Returns:
        A representation of the dataclass. Fields that have neither a default nor a
        `default_factory` are required and are always included.
    """
    if default_values is None:
        default_values = {}

    non_default_values = {}  # Maps field name to value.

    for field in fields(obj):
        value = getattr(obj, field.name)

        has_default = not isinstance(field.default, _MISSING_TYPE)
        has_factory = not isinstance(field.default_factory, _MISSING_TYPE)

        # A field is only truly required when the dataclass offers no way to
        # default it. Checking `field.default` alone would misclassify fields
        # declared with `default_factory` as required and always print them.
        if not (has_default or has_factory):
            non_default_values[field.name] = value
            continue

        if field.name in default_values:
            # An explicit override wins over whatever the dataclass declares.
            default_value = default_values[field.name]
        elif has_default:
            default_value = field.default
        else:
            # Build a fresh default so it can be compared against the value.
            default_value = field.default_factory()

        if value != default_value:
            non_default_values[field.name] = value

    arguments = ", ".join(
        f"{name}={value!r}" for name, value in non_default_values.items()
    )
    return f"{obj.__class__.__name__}({arguments})"


@dataclass
@PublicAPI(stability="alpha")
class ScalingConfig:
Expand Down Expand Up @@ -86,6 +134,9 @@ def __post_init__(self):
"`resources_per_worker."
)

def __repr__(self):
    """Return a compact representation listing only non-default fields.

    Delegates to `_repr_dataclass`, which omits fields whose values equal
    their defaults.
    """
    return _repr_dataclass(self)

def __eq__(self, o: "ScalingConfig") -> bool:
if not isinstance(o, type(self)):
return False
Expand Down Expand Up @@ -262,6 +313,9 @@ class DatasetConfig:
# True by default.
randomize_block_order: Optional[bool] = None

def __repr__(self):
    """Return a compact representation listing only non-default fields.

    Delegates to `_repr_dataclass`, which omits fields whose values equal
    their defaults.
    """
    return _repr_dataclass(self)

def fill_defaults(self) -> "DatasetConfig":
"""Return a copy of this config with all default values filled in."""
return DatasetConfig(
Expand Down Expand Up @@ -394,6 +448,9 @@ def __post_init__(self):
"fail_fast must be one of {bool, 'raise'}. " f"Got {self.fail_fast}."
)

def __repr__(self):
    """Return a compact representation listing only non-default fields.

    Delegates to `_repr_dataclass`, which omits fields whose values equal
    their defaults.
    """
    return _repr_dataclass(self)


@dataclass
@PublicAPI(stability="alpha")
Expand Down Expand Up @@ -458,6 +515,9 @@ def __post_init__(self):
f"checkpoint_frequency must be >=0, got {self.checkpoint_frequency}"
)

def __repr__(self):
    """Return a compact representation listing only non-default fields.

    Delegates to `_repr_dataclass`, which omits fields whose values equal
    their defaults.
    """
    return _repr_dataclass(self)

@property
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
Expand Down Expand Up @@ -540,3 +600,15 @@ def __post_init__(self):

if not self.checkpoint_config:
self.checkpoint_config = CheckpointConfig()

def __repr__(self):
    """Return a compact representation listing only non-default fields.

    Freshly constructed `FailureConfig`, `SyncConfig` and `CheckpointConfig`
    instances are passed to `_repr_dataclass` as the defaults for the three
    nested config fields, so a nested config equal to such an instance is
    left out of the output.
    """
    # Function-scope import, as in the original.
    # NOTE(review): presumably avoids an import cycle with ray.tune; confirm.
    from ray.tune.syncer import SyncConfig

    nested_defaults = dict(
        failure_config=FailureConfig(),
        sync_config=SyncConfig(),
        checkpoint_config=CheckpointConfig(),
    )
    return _repr_dataclass(self, default_values=nested_defaults)
4 changes: 4 additions & 0 deletions python/ray/air/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@

# Name to use for the column when representing tensors in table format.
TENSOR_COLUMN_NAME = "__value__"

# The maximum length of strings returned by `__repr__` for AIR objects constructed with
# default values.
MAX_REPR_LENGTH = int(80 * 1.5)
13 changes: 12 additions & 1 deletion python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import re
import shutil
import tempfile
import unittest
Expand All @@ -8,7 +9,7 @@
import ray
from ray.air._internal.remote_storage import delete_at_uri, _ensure_directory
from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
from ray.air.constants import PREPROCESSOR_KEY
from ray.air.constants import MAX_REPR_LENGTH, PREPROCESSOR_KEY
from ray.data import Preprocessor


Expand All @@ -20,6 +21,16 @@ def transform_batch(self, df):
return df * self.multiplier


def test_repr():
    """A dict-backed checkpoint's repr is short and shaped like `Checkpoint(...)`."""
    text = repr(Checkpoint(data_dict={"foo": "bar"}))

    assert len(text) < MAX_REPR_LENGTH
    # The pattern is anchored at both ends: the whole repr must be one
    # `Checkpoint(...)` call.
    assert re.match("^Checkpoint\\((.*)\\)$", text)


class CheckpointsConversionTest(unittest.TestCase):
def setUp(self):
self.tmpdir = os.path.realpath(tempfile.mkdtemp())
Expand Down
41 changes: 41 additions & 0 deletions python/ray/air/tests/test_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from ray.air.config import (
ScalingConfig,
DatasetConfig,
FailureConfig,
CheckpointConfig,
RunConfig,
)
from ray.air.constants import MAX_REPR_LENGTH


@pytest.mark.parametrize(
    "config",
    [
        ScalingConfig(),
        ScalingConfig(use_gpu=True),
        DatasetConfig(),
        DatasetConfig(fit=True),
        FailureConfig(),
        FailureConfig(max_failures=2),
        CheckpointConfig(),
        CheckpointConfig(num_to_keep=1),
        RunConfig(),
        RunConfig(name="experiment"),
        RunConfig(failure_config=FailureConfig()),
    ],
)
def test_repr(config):
    """Each config's repr round-trips through `eval` and stays short."""
    text = repr(config)

    # The repr is generated from objects built in this test, so evaluating
    # it is safe. It must rebuild an equal config.
    rebuilt = eval(text)
    assert rebuilt == config
    assert len(text) < MAX_REPR_LENGTH


# Allow running this test module directly: invoke pytest on this file in
# verbose mode (-v), stopping at the first failure (-x), and exit with
# pytest's return code.
if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", "-x", __file__]))
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3586,7 +3586,7 @@ def __repr__(self) -> str:
schema_str = ", ".join(schema_str)
schema_str = "{" + schema_str + "}"
count = self._meta_count()
return "Dataset(num_blocks={}, num_rows={}, schema={})".format(
return "<Dataset num_blocks={} num_rows={} schema={}>".format(
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
self._plan.initial_num_blocks(), count, schema_str
)

Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/preprocessors/batch_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":

def __repr__(self):
fn_name = getattr(self.fn, "__name__", self.fn)
return f"BatchMapper(fn={fn_name})"
return f"{self.__class__.__name__}(fn={fn_name})"
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion python/ray/data/preprocessors/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ def _transform_batch(self, df: "DataBatchType") -> "DataBatchType":
return df

def __repr__(self):
return f"Chain(preprocessors={self.preprocessors})"
arguments = ", ".join(repr(preprocessor) for preprocessor in self.preprocessors)
return f"{self.__class__.__name__}({arguments})"
43 changes: 26 additions & 17 deletions python/ray/data/preprocessors/concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,24 @@ def __init__(
raise_if_missing: bool = False,
):
self.output_column_name = output_column_name
self.included_columns = include
self.excluded_columns = exclude or []
self.include = include
self.exclude = exclude or []
self.dtype = dtype
self.raise_if_missing = raise_if_missing

def _validate(self, df: pd.DataFrame):
total_columns = set(df)
if self.excluded_columns and self.raise_if_missing:
missing_columns = set(self.excluded_columns) - total_columns.intersection(
set(self.excluded_columns)
if self.exclude and self.raise_if_missing:
missing_columns = set(self.exclude) - total_columns.intersection(
set(self.exclude)
)
if missing_columns:
raise ValueError(
f"Missing columns specified in 'exclude': {missing_columns}"
)
if self.included_columns and self.raise_if_missing:
missing_columns = set(self.included_columns) - total_columns.intersection(
set(self.included_columns)
if self.include and self.raise_if_missing:
missing_columns = set(self.include) - total_columns.intersection(
set(self.include)
)
if missing_columns:
raise ValueError(
Expand All @@ -84,10 +84,10 @@ def _transform_pandas(self, df: pd.DataFrame):
self._validate(df)

included_columns = set(df)
if self.included_columns: # subset of included columns
included_columns = set(self.included_columns)
if self.include: # subset of included columns
included_columns = set(self.include)

columns_to_concat = list(included_columns - set(self.excluded_columns))
columns_to_concat = list(included_columns - set(self.exclude))
concatenated = df[columns_to_concat].to_numpy(dtype=self.dtype)
df = df.drop(columns=columns_to_concat)
try:
Expand All @@ -98,9 +98,18 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
return (
f"Concatenator(output_column_name={self.output_column_name}, "
f"include={self.included_columns}, "
f"exclude={self.excluded_columns}, "
f"dtype={self.dtype})"
)
default_values = {
"output_column_name": "concat_out",
"include": None,
"exclude": [],
"dtype": None,
"raise_if_missing": False,
}

non_default_arguments = []
for parameter, default_value in default_values.items():
value = getattr(self, parameter)
if value != default_value:
non_default_arguments.append(f"{parameter}={value}")

return f"{self.__class__.__name__}({', '.join(non_default_arguments)})"
Comment on lines +101 to +115
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this so the representation isn't too long.

24 changes: 13 additions & 11 deletions python/ray/data/preprocessors/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ def list_as_category(element):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"OrdinalEncoder(columns={self.columns}, stats={stats}, "
f"encode_lists={self.encode_lists})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"encode_lists={self.encode_lists!r})"
)


Expand Down Expand Up @@ -191,8 +190,10 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"OneHotEncoder(columns={self.columns}, stats={stats})"
return (
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"max_categories={self.max_categories!r})"
)


class MultiHotEncoder(Preprocessor):
Expand Down Expand Up @@ -277,8 +278,10 @@ def encode_list(element: list, *, name: str):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"MultiHotEncoder(columns={self.columns}, stats={stats})"
return (
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"max_categories={self.max_categories!r})"
)


class LabelEncoder(Preprocessor):
Expand Down Expand Up @@ -313,8 +316,7 @@ def column_label_encoder(s: pd.Series):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return f"LabelEncoder(label_column={self.label_column}, stats={stats})"
return f"{self.__class__.__name__}(label_column={self.label_column!r})"


class Categorizer(Preprocessor):
Expand Down Expand Up @@ -367,9 +369,9 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"<Categorizer columns={self.columns} dtypes={self.dtypes} stats={stats}>"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"dtypes={self.dtypes!r})"
)


Expand Down
3 changes: 2 additions & 1 deletion python/ray/data/preprocessors/hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ def row_feature_hasher(row):

def __repr__(self):
return (
f"FeatureHasher(columns={self.columns}, num_features={self.num_features})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"num_features={self.num_features!r})"
)
8 changes: 2 additions & 6 deletions python/ray/data/preprocessors/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,9 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
stats = getattr(self, "stats_", None)
return (
f"SimpleImputer("
f"columns={self.columns}, "
f"strategy={self.strategy}, "
f"fill_value={self.fill_value}, "
f"stats={stats})"
f"{self.__class__.__name__}(columns={self.columns!r}, "
f"strategy={self.strategy!r}, fill_value={self.fill_value!r})"
)


Expand Down
4 changes: 3 additions & 1 deletion python/ray/data/preprocessors/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ def _transform_pandas(self, df: pd.DataFrame):
return df

def __repr__(self):
return f"Normalizer(columns={self.columns}, norm={self.norm})>"
return (
f"{self.__class__.__name__}(columns={self.columns!r}, norm={self.norm!r})"
)
Loading