From 88da2afc415ed685058633caf68e1daf61cec2d9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 22 Jul 2022 06:14:38 -0700 Subject: [PATCH 1/2] Finished --- dashboard/state_aggregator.py | 36 +++++++- python/ray/experimental/state/api.py | 99 +++++++++++++++++--- python/ray/experimental/state/common.py | 8 ++ python/ray/experimental/state/state_cli.py | 10 +- python/ray/tests/test_state_api.py | 102 ++++++++++++++++++++- 5 files changed, 235 insertions(+), 20 deletions(-) diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index c7788ab1027f..217b1c4a81ce 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -194,7 +194,9 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: result = [] for message in reply.actor_table_data: - data = self._message_to_dict(message=message, fields_to_decode=["actor_id"]) + data = self._message_to_dict( + message=message, fields_to_decode=["actor_id", "owner_id"] + ) result.append(data) result = self._filter(result, option.filters, ActorState, option.detail) @@ -443,6 +445,7 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: result = [] memory_table = memory_utils.construct_memory_table(worker_stats) + callsite_enabled = True for entry in memory_table.table: data = entry.as_dict() # `construct_memory_table` returns object_ref field which is indeed @@ -452,8 +455,23 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: del data["object_ref"] data["ip"] = data["node_ip_address"] del data["node_ip_address"] + + # If there's any "disabled" callsite, we consider the callsite collection + # is disabled. + if data["call_site"] == "disabled": + callsite_enabled = False result.append(data) + # Add callsite warnings if it is not configured. + callsite_warning = [] + if not callsite_enabled: + callsite_warning.append( + "Callsite is not being recorded. " + "To record callsite information for each ObjectRef created, set " + "env variable RAY_record_ref_creation_sites=1 during `ray start` " + "and and `ray.init`." + ) + result = self._filter(result, option.filters, ObjectState, option.detail) # Sort to make the output deterministic. result.sort(key=lambda entry: entry["object_id"]) @@ -462,6 +480,7 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: result=result, partial_failure_warning=partial_failure_warning, total=total_objects, + warnings=callsite_warning, ) async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: @@ -551,7 +570,10 @@ async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse } ) return SummaryApiResponse( - result=summary, partial_failure_warning=result.partial_failure_warning + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=result.warnings, ) async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse: @@ -565,7 +587,10 @@ async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiRespons } ) return SummaryApiResponse( - result=summary, partial_failure_warning=result.partial_failure_warning + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=result.warnings, ) async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse: @@ -579,7 +604,10 @@ async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiRespon } ) return SummaryApiResponse( - result=summary, partial_failure_warning=result.partial_failure_warning + total=result.total, + result=summary, + partial_failure_warning=result.partial_failure_warning, + warnings=result.warnings, ) def _message_to_dict( diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index 6bd84b70d2eb..fb54cca23545 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -218,7 +218,7 @@ def _make_http_get_request( f"Error: {response['msg']}" ) - # Dictionary of `ListApiResponse` + # Dictionary of `ListApiResponse` or `SummaryApiResponse` return response["data"]["result"] def get( @@ -311,22 +311,22 @@ def get( assert len(result) == 1 return result[0] - def _print_list_api_warning(self, resource: StateResource, list_api_response: dict): + def _print_api_warning(self, resource: StateResource, api_response: dict): """Print the API warnings. Args: resource: Resource names, i.e. 'jobs', 'actors', 'nodes', see `StateResource` for details. - list_api_response: The dictionarified `ListApiResponse`. + api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. """ # Print warnings if anything was given. - warning_msgs = list_api_response.get("partial_failure_warning", None) + warning_msgs = api_response.get("partial_failure_warning", None) if warning_msgs: warnings.warn(warning_msgs) # Print warnings if data is truncated. - data = list_api_response["result"] - total = list_api_response["total"] + data = api_response["result"] + total = api_response["total"] if total > len(data): warnings.warn( ( @@ -337,8 +337,42 @@ def _print_list_api_warning(self, resource: StateResource, list_api_response: di ), ) + # Print the additional warnings. + warnings_to_print = api_response.get("warnings", []) + if warnings_to_print: + for warning_to_print in warnings_to_print: + warnings.warn(warning_to_print) + + def _raise_on_missing_output(self, resource: StateResource, api_response: dict): + """Raise an exception when the API resopnse contains a missing output. + + Output can be missing if (1) Failures on some of data source queries (e.g., + `ray list tasks` queries all raylets, and if some of quries fail, it will + contain missing output. If all quries fail, it will just fail). (2) Data + is truncated because the output is too large. + + Args: + resource: Resource names, i.e. 'jobs', 'actors', 'nodes', + see `StateResource` for details. + api_response: The dictionarified `ListApiResponse` or `SummaryApiResponse`. + """ + warning_msgs = api_response.get("partial_failure_warning", None) + # TODO(sang) raise an exception on truncation after + # https://github.com/ray-project/ray/pull/26801. + if warning_msgs: + raise RayStateApiException( + f"Failed to retrieve all {resource.value} from the cluster. " + f"It can happen when some of {resource.value} information is not " + "reachable the returned data is truncated because it is too large. " + "To allow having missing output, set `raise_on_missing_output=False`. " + ) + def list( - self, resource: StateResource, options: ListApiOptions, _explain: bool = False + self, + resource: StateResource, + options: ListApiOptions, + raise_on_missing_output: bool, + _explain: bool = False, ) -> Union[Dict, List]: """List resources states @@ -346,6 +380,10 @@ def list( resource: Resource names, i.e. 'jobs', 'actors', 'nodes', see `StateResource` for details. options: List options. See `ListApiOptions` for details. + raise_on_missing_output: Raise an exception if the output has missing data. + Output can have missing data if (1) there's a partial network failure + when the source is distributed. (2) data is truncated + because it is too large. _explain: Print the API information such as API latency or failed query information. @@ -366,8 +404,10 @@ def list( timeout=options.timeout, _explain=_explain, ) + if raise_on_missing_output: + self._raise_on_missing_output(resource, list_api_response) if _explain: - self._print_list_api_warning(resource, list_api_response) + self._print_api_warning(resource, list_api_response) return list_api_response["result"] def summary( @@ -375,6 +415,7 @@ def summary( resource: SummaryResource, *, options: SummaryApiOptions, + raise_on_missing_output: bool, _explain: bool = False, ) -> Dict: """Summarize resources states @@ -384,6 +425,12 @@ def summary( see `SummaryResource` for details. options: summary options. See `SummaryApiOptions` for details. A dictionary of queried result from `SummaryApiResponse`, + raise_on_missing_output: Raise an exception if the output has missing data. + Output can have missing data if (1) there's a partial network failure + when the source is distributed. (2) data is truncated + because it is too large. + _explain: Print the API information such as API + latency or failed query information. Raises: This doesn't catch any exceptions raised when the underlying request @@ -392,14 +439,17 @@ def summary( """ params = {"timeout": options.timeout} endpoint = f"/api/v0/{resource.value}/summarize" - list_api_response = self._make_http_get_request( + summary_api_response = self._make_http_get_request( endpoint=endpoint, params=params, timeout=options.timeout, _explain=_explain, ) - result = list_api_response["result"] - return result["node_id_to_summary"] + if raise_on_missing_output: + self._raise_on_missing_output(resource, summary_api_response) + # TODO(sang): Add warning after + # # https://github.com/ray-project/ray/pull/26801 is merged. + return summary_api_response["result"]["node_id_to_summary"] """ @@ -518,13 +568,18 @@ def list_actors( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( StateResource.ACTORS, options=ListApiOptions( - limit=limit, timeout=timeout, filters=filters, detail=detail + limit=limit, + timeout=timeout, + filters=filters, + detail=detail, ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -535,6 +590,7 @@ def list_placement_groups( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -542,6 +598,7 @@ def list_placement_groups( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -552,6 +609,7 @@ def list_nodes( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -559,6 +617,7 @@ def list_nodes( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -569,6 +628,7 @@ def list_jobs( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -576,6 +636,7 @@ def list_jobs( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -586,6 +647,7 @@ def list_workers( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -593,6 +655,7 @@ def list_workers( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -603,6 +666,7 @@ def list_tasks( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -610,6 +674,7 @@ def list_tasks( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -620,6 +685,7 @@ def list_objects( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -627,6 +693,7 @@ def list_objects( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -637,6 +704,7 @@ def list_runtime_envs( limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, detail: bool = False, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).list( @@ -644,6 +712,7 @@ def list_runtime_envs( options=ListApiOptions( limit=limit, timeout=timeout, filters=filters, detail=detail ), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -759,11 +828,13 @@ def list_logs( def summarize_tasks( address: str = None, timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).summary( SummaryResource.TASKS, options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -771,11 +842,13 @@ def summarize_tasks( def summarize_actors( address: str = None, timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).summary( SummaryResource.ACTORS, options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) @@ -783,10 +856,12 @@ def summarize_actors( def summarize_objects( address: str = None, timeout: int = DEFAULT_RPC_TIMEOUT, + raise_on_missing_output: bool = True, _explain: bool = False, ): return StateApiClient(address=address).summary( SummaryResource.OBJECTS, options=SummaryApiOptions(timeout=timeout), + raise_on_missing_output=raise_on_missing_output, _explain=_explain, ) diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index dc8f3224092b..a516e06dccb0 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -380,6 +380,8 @@ class ListApiResponse: # availability of data because ray's state information is # not replicated. partial_failure_warning: str = "" + # A list of warnings to print. + warnings: Optional[List[str]] = None """ @@ -608,5 +610,11 @@ class StateSummary: @dataclass(init=True) class SummaryApiResponse: + # Total number of the resource from the cluster. + # Note that this value can be larger than `result` + # because `result` can be truncated. + total: int result: StateSummary = None partial_failure_warning: str = "" + # A list of warnings to print. + warnings: Optional[List[str]] = None diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index fd727b2974f1..89e03695d942 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -472,7 +472,12 @@ def list( ) # If errors occur, exceptions will be thrown. Empty data indicate successful query. - data = client.list(resource, options=options, _explain=_should_explain(format)) + data = client.list( + resource, + options=options, + raise_on_missing_output=False, + _explain=_should_explain(format), + ) # Print data to console. print( @@ -500,6 +505,7 @@ def task_summary(ctx, timeout: float, address: str): summarize_tasks( address=address, timeout=timeout, + raise_on_missing_output=False, _explain=True, ), resource=StateResource.TASKS, @@ -517,6 +523,7 @@ def actor_summary(ctx, timeout: float, address: str): summarize_actors( address=address, timeout=timeout, + raise_on_missing_output=False, _explain=True, ), resource=StateResource.ACTORS, @@ -534,6 +541,7 @@ def object_summary(ctx, timeout: float, address: str): summarize_objects( address=address, timeout=timeout, + raise_on_missing_output=False, _explain=True, ), ) diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 6141c5ffa155..815614063f6b 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -66,6 +66,7 @@ list_runtime_envs, list_tasks, list_workers, + summarize_tasks, StateApiClient, ) from ray.experimental.state.common import ( @@ -1914,11 +1915,11 @@ def f(): cluster.remove_node(n, allow_graceful=False) with pytest.warns(UserWarning): - list_tasks(_explain=True) + list_tasks(raise_on_missing_output=False, _explain=True) # Make sure when _explain == False, warning is not printed. with pytest.warns(None) as record: - list_tasks(_explain=False) + list_tasks(raise_on_missing_output=False, _explain=False) assert len(record) == 0 @@ -1946,7 +1947,7 @@ def f(): def verify(): with pytest.warns(None) as record: - list_tasks(_explain=True, timeout=5) + list_tasks(raise_on_missing_output=False, _explain=True, timeout=5) return len(record) == 1 wait_for_condition(verify) @@ -2423,6 +2424,101 @@ def verify(): wait_for_condition(verify) +@pytest.mark.parametrize("callsite_enabled", [True, False]) +def test_callsite_warning(callsite_enabled, monkeypatch, shutdown_only): + # Set environment + with monkeypatch.context() as m: + m.setenv("RAY_record_ref_creation_sites", str(int(callsite_enabled))) + ray.init() + + a = ray.put(1) # noqa + + runner = CliRunner() + wait_for_condition(lambda: len(list_objects()) > 0) + + with pytest.warns(None) as record: + result = runner.invoke(cli_list, ["objects"]) + assert result.exit_code == 0 + + if callsite_enabled: + assert len(record) == 0 + else: + assert len(record) == 1 + assert "RAY_record_ref_creation_sites=1" in str(record[0].message) + + +def test_raise_on_missing_output_partial_failures(monkeypatch, ray_start_cluster): + """ + Verify when there are network partial failures, + state API raises an exception when `raise_on_missing_output=True`. + """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_testing_asio_delay_us", + "NodeManagerService.grpc_server.GetTasksInfo=10000000:10000000", + ) + cluster.add_node(num_cpus=2) + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + + runner = CliRunner() + + # Verify + def verify(): + # Verify when raise_on_missing_output=True, it raises an exception. + try: + list_tasks(_explain=True, timeout=3) + except RayStateApiException as e: + assert "Failed to retrieve all tasks from the cluster." in str(e) + else: + assert False + + try: + summarize_tasks(_explain=True, timeout=3) + except RayStateApiException as e: + assert "Failed to retrieve all tasks from the cluster." in str(e) + else: + assert False + + # Verify when raise_on_missing_output=False, it prints warnings. + with pytest.warns(None) as record: + list_tasks(raise_on_missing_output=False, _explain=True, timeout=3) + assert len(record) == 1 + + # TODO(sang): Add warning after https://github.com/ray-project/ray/pull/26801 + # is merged. + # with pytest.warns(None) as record: + # summarize_tasks(raise_on_missing_output=False, _explain=True, timeout=3) + # assert len(record) == 1 + + # Verify when CLI is used, exceptions are not raised. + with pytest.warns(None) as record: + result = runner.invoke(cli_list, ["tasks", "--timeout=3"]) + assert len(record) == 1 + assert result.exit_code == 0 + + # TODO(sang): Add warning after https://github.com/ray-project/ray/pull/26801 + # is merged. + # Verify summary CLI also doesn't raise an exception. + # with pytest.warns(None) as record: + # result = runner.invoke(task_summary, ["--timeout=3"]) + # assert result.exit_code == 0 + # assert len(record) == 1 + return True + + wait_for_condition(verify) + + if __name__ == "__main__": import os import sys From 742ca583c2a86790dc0e727f6d206e5bd8db687e Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 24 Jul 2022 04:21:52 -0700 Subject: [PATCH 2/2] Addressed code review. --- dashboard/state_aggregator.py | 11 ++++------- dashboard/tests/test_dashboard.py | 4 +++- python/ray/experimental/state/api.py | 11 ++++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index 217b1c4a81ce..f9cb45956309 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -5,6 +5,8 @@ from itertools import islice from typing import List, Tuple +from ray._private.ray_constants import env_integer + import ray.dashboard.memory_utils as memory_utils import ray.dashboard.utils as dashboard_utils from ray._private.utils import binary_to_hex @@ -445,7 +447,6 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: result = [] memory_table = memory_utils.construct_memory_table(worker_stats) - callsite_enabled = True for entry in memory_table.table: data = entry.as_dict() # `construct_memory_table` returns object_ref field which is indeed @@ -455,21 +456,17 @@ async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: del data["object_ref"] data["ip"] = data["node_ip_address"] del data["node_ip_address"] - - # If there's any "disabled" callsite, we consider the callsite collection - # is disabled. - if data["call_site"] == "disabled": - callsite_enabled = False result.append(data) # Add callsite warnings if it is not configured. callsite_warning = [] + callsite_enabled = env_integer("RAY_record_ref_creation_sites", 0) if not callsite_enabled: callsite_warning.append( "Callsite is not being recorded. " "To record callsite information for each ObjectRef created, set " "env variable RAY_record_ref_creation_sites=1 during `ray start` " - "and and `ray.init`." + "and `ray.init`." ) result = self._filter(result, option.filters, ObjectState, option.detail) diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index c8d1e45a66cc..d8ca31714fcf 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -961,7 +961,9 @@ def test_dashboard_requests_fail_on_missing_deps(ray_start_with_dashboard): with pytest.raises(ServerUnavailable): client = StateApiClient(address=DEFAULT_DASHBOARD_ADDRESS) - response = client.list(StateResource.NODES, options=ListApiOptions()) + response = client.list( + StateResource.NODES, options=ListApiOptions(), raise_on_missing_output=False + ) # Response should not be populated assert response is None diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index fb54cca23545..b7ddb7f2d24b 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -363,7 +363,7 @@ def _raise_on_missing_output(self, resource: StateResource, api_response: dict): raise RayStateApiException( f"Failed to retrieve all {resource.value} from the cluster. " f"It can happen when some of {resource.value} information is not " - "reachable the returned data is truncated because it is too large. " + "reachable or the returned data is truncated because it is too large. " "To allow having missing output, set `raise_on_missing_output=False`. " ) @@ -380,10 +380,11 @@ def list( resource: Resource names, i.e. 'jobs', 'actors', 'nodes', see `StateResource` for details. options: List options. See `ListApiOptions` for details. - raise_on_missing_output: Raise an exception if the output has missing data. - Output can have missing data if (1) there's a partial network failure - when the source is distributed. (2) data is truncated - because it is too large. + raise_on_missing_output: When True, raise an exception if the output + is incomplete. Output can be incomplete if + (1) there's a partial network failure when the source is distributed. + (2) data is truncated because it is too large. + Set it to False to avoid throwing an exception on missing data. _explain: Print the API information such as API latency or failed query information.