diff --git a/python/raydp/torch/config.py b/python/raydp/torch/config.py index 19965734..b4d705c4 100644 --- a/python/raydp/torch/config.py +++ b/python/raydp/torch/config.py @@ -16,7 +16,15 @@ class TorchConfig(RayTorchConfig): def backend_cls(self): return EnableCCLBackend -def ccl_import(): +def libs_import(): + """try to import IPEX and oneCCL. + """ + try: + import intel_extension_for_pytorch + except ImportError: + raise ImportError( + "Please install intel_extension_for_pytorch" + ) try: ccl_version = importlib_metadata.version("oneccl_bind_pt") if ccl_version >= "1.12": @@ -33,5 +41,5 @@ class EnableCCLBackend(_TorchBackend): def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): for i in range(len(worker_group)): - worker_group.execute_single_async(i, ccl_import) + worker_group.execute_single_async(i, libs_import) super().on_start(worker_group, backend_config)