Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
mht-sharma committed Nov 24, 2023
1 parent 0fe7545 commit 1551292
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ def get_provider_for_device(device: torch.device) -> str:
Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
"""
if device.type.lower() == "cuda":
if "CUDAExecutionProvider" in ort.get_available_providers():
return "CUDAExecutionProvider"
else:
if "ROCMExecutionProvider" in ort.get_available_providers():
return "ROCMExecutionProvider"
else:
return "CUDAExecutionProvider"
return "CPUExecutionProvider"


Expand Down

0 comments on commit 1551292

Please sign in to comment.