diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 387d7b9db1d0..e9e878a186cb 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -1,6 +1,7 @@ # isort: skip_file import logging import os +import sys logger = logging.getLogger(__name__) @@ -115,7 +116,7 @@ def _configure_system(): import ray._raylet # noqa: E402 -from ray._raylet import ( # noqa: E402 +from ray._raylet import ( # noqa: E402,F401 ActorClassID, ActorID, NodeID, @@ -134,7 +135,7 @@ def _configure_system(): _config = _Config() -from ray._private.state import ( # noqa: E402 +from ray._private.state import ( # noqa: E402,F401 nodes, timeline, cluster_resources, @@ -162,23 +163,21 @@ def _configure_system(): # We import ray.actor because some code is run in actor.py which initializes # some functions in the worker. import ray.actor # noqa: E402,F401 -from ray.actor import method # noqa: E402 +from ray.actor import method # noqa: E402,F401 # TODO(qwang): We should remove this exporting in Ray2.0. -from ray.cross_language import java_function, java_actor_class # noqa: E402 -from ray.runtime_context import get_runtime_context # noqa: E402 -from ray import autoscaler # noqa:E402 -from ray import data # noqa: E402,F401 +from ray.cross_language import java_function, java_actor_class # noqa: E402,F401 +from ray.runtime_context import get_runtime_context # noqa: E402,F401 +from ray import autoscaler # noqa: E402,F401 from ray import internal # noqa: E402,F401 -from ray import util # noqa: E402 +from ray import util # noqa: E402,F401 from ray import _private # noqa: E402,F401 -from ray import workflow # noqa: E402,F401 # We import ClientBuilder so that modules can inherit from `ray.ClientBuilder`. -from ray.client_builder import client, ClientBuilder # noqa: E402 +from ray.client_builder import client, ClientBuilder # noqa: E402,F401 -class _DeprecationWrapper(object): +class _DeprecationWrapper: def __init__(self, name, real_worker): self._name = name self._real_worker = real_worker @@ -201,18 +200,23 @@ def __getattr__(self, attr): serialization = _DeprecationWrapper("serialization", ray._private.serialization) state = _DeprecationWrapper("state", ray._private.state) + +_subpackages = [ + "data", + "workflow", +] + __all__ = [ "__version__", "_config", "get_runtime_context", "actor", - "available_resources", "autoscaler", + "available_resources", "cancel", "client", "ClientBuilder", "cluster_resources", - "data", "get", "get_actor", "get_gpu_ids", @@ -237,7 +241,7 @@ def __getattr__(self, attr): "LOCAL_MODE", "SCRIPT_MODE", "WORKER_MODE", -] +] + _subpackages # ID types __all__ += [ @@ -255,6 +259,20 @@ def __getattr__(self, attr): "PlacementGroupID", ] +if sys.version_info < (3, 7): + # TODO(Clark): Remove this one we drop Python 3.6 support. + from ray import data # noqa: F401 + from ray import workflow # noqa: F401 +else: + # Delay importing of expensive, isolated subpackages. + def __getattr__(name: str): + import importlib + + if name in _subpackages: + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + del os del logging +del sys diff --git a/python/ray/tests/test_runtime_env_working_dir_4.py b/python/ray/tests/test_runtime_env_working_dir_4.py index f9a7fd742e75..5d767bba33d8 100644 --- a/python/ray/tests/test_runtime_env_working_dir_4.py +++ b/python/ray/tests/test_runtime_env_working_dir_4.py @@ -95,7 +95,7 @@ def check(self): # this delay will make worker start slow and time out "testing_asio_delay_us": "InternalKVGcsService.grpc_server" ".InternalKVGet=2000000:2000000", - "worker_register_timeout_seconds": 1, + "worker_register_timeout_seconds": 0.5, }, }, { @@ -105,7 +105,7 @@ def check(self): # this delay will make worker start slow and time out "testing_asio_delay_us": "InternalKVGcsService.grpc_server" ".InternalKVGet=2000000:2000000", - "worker_register_timeout_seconds": 1, + "worker_register_timeout_seconds": 0.5, }, }, ], diff --git a/python/ray/tests/test_top_level_api.py b/python/ray/tests/test_top_level_api.py index 5eb02a55ee8f..6d226be64003 100644 --- a/python/ray/tests/test_top_level_api.py +++ b/python/ray/tests/test_top_level_api.py @@ -1,4 +1,7 @@ from inspect import getmembers, isfunction, ismodule +import sys + +import pytest import ray @@ -36,9 +39,13 @@ def test_api_functions(): "get_runtime_context", ] + IMPL_FUNCTIONS = ["__getattr__"] + functions = getmembers(ray, isfunction) function_names = [f[0] for f in functions] - assert set(function_names) == set(PYTHON_API + OTHER_ALLOWED_FUNCTIONS) + assert set(function_names) == set( + PYTHON_API + OTHER_ALLOWED_FUNCTIONS + IMPL_FUNCTIONS + ) def test_non_ray_modules(): @@ -47,6 +54,43 @@ def test_non_ray_modules(): assert "ray" in str(mod), f"Module {mod} should not be reachable via ray.{name}" +def test_dynamic_subpackage_import(): + # Test that subpackages are dynamically imported and properly cached. + + # ray.data + assert "ray.data" not in sys.modules + ray.data + # Check that the package is cached. + assert "ray.data" in sys.modules + + # ray.workflow + assert "ray.workflow" not in sys.modules + ray.workflow + # Check that the package is cached. + assert "ray.workflow" in sys.modules + + +def test_dynamic_subpackage_missing(): + # Test nonexistent subpackage dynamic attribute access and imports raise expected + # errors. + + # Test that nonexistent subpackage attribute access raises an AttributeError. + with pytest.raises(AttributeError): + ray.foo # noqa:F401 + + # Test that nonexistent subpackage import raises an ImportError. + with pytest.raises(ImportError): + from ray.foo import bar # noqa:F401 + + +def test_dynamic_subpackage_fallback_only(): + # Test that the __getattr__ dynamic + assert "ray.autoscaler" in sys.modules + assert ray.__getattribute__("autoscaler") is ray.autoscaler + with pytest.raises(AttributeError): + ray.__getattr__("autoscaler") + + def test_for_strings(): strings = getmembers(ray, lambda obj: isinstance(obj, str)) for string, _ in strings: @@ -55,9 +99,7 @@ def test_for_strings(): if __name__ == "__main__": - import pytest import os - import sys if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))