Skip to content

Commit

Permalink
import IPEX in torch config (#385)
Browse files Browse the repository at this point in the history
* import IPEX in torch config

* change message
  • Loading branch information
harborn authored Oct 16, 2023
1 parent d3a614f commit a93059b
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/raydp/torch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ class TorchConfig(RayTorchConfig):
def backend_cls(self):
return EnableCCLBackend

def ccl_import():
def libs_import():
"""try to import IPEX and oneCCL.
"""
try:
import intel_extension_for_pytorch
except ImportError:
raise ImportError(
"Please install intel_extension_for_pytorch"
)
try:
ccl_version = importlib_metadata.version("oneccl_bind_pt")
if ccl_version >= "1.12":
Expand All @@ -33,5 +41,5 @@ class EnableCCLBackend(_TorchBackend):

def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
for i in range(len(worker_group)):
worker_group.execute_single_async(i, ccl_import)
worker_group.execute_single_async(i, libs_import)
super().on_start(worker_group, backend_config)

0 comments on commit a93059b

Please sign in to comment.