Add support for DML execution provider (#1130)
## Add support for DML execution provider
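
This change lets the ModelBuilder pass target ONNX Runtime's DirectML execution provider in addition to CPU and CUDA. As a rough usage sketch (not code from this PR; field names follow Olive's general workflow-config shape and may differ across versions), a run targeting the DML EP could look like:

```python
# Hypothetical sketch, not part of this PR: an Olive workflow config that
# selects the DirectML EP, which this change teaches ModelBuilder to handle.
# Treat field names as illustrative; the exact schema varies by Olive version.
config = {
    "input_model": {"type": "PyTorchModel", "model_path": "<local path or HF id>"},
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "accelerators": [{"device": "gpu", "execution_providers": ["DmlExecutionProvider"]}],
        }
    },
    "passes": {"builder": {"type": "ModelBuilder", "config": {"precision": "fp16"}}},
    "engine": {"host": "local_system", "target": "local_system"},
}

# A real run would hand this dict to Olive's entry point:
# from olive.workflows import run as olive_run
# olive_run(config)
```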

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
shaahji authored May 1, 2024
1 parent d5337d5 commit d0419e7
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions olive/passes/onnx/model_builder.py
@@ -13,7 +13,7 @@
 from pathlib import Path
 from typing import Any, Dict, Union
 
-from olive.hardware.accelerator import AcceleratorLookup, AcceleratorSpec, Device
+from olive.hardware.accelerator import AcceleratorSpec, Device
 from olive.model import ONNXModelHandler, PyTorchModelHandler
 from olive.model.utils import resolve_onnx_path
 from olive.passes import Pass
@@ -86,7 +86,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
             ),
             "enable_cuda_graph": PassConfigParam(
                 type_=bool,
-                default_value=False,
+                default_value=None,  # Explicitly set to None to distinguish user intent from the default.
                 required=False,
                 description=(
                     "The model can use CUDA graph capture for CUDA execution provider. "
@@ -102,13 +102,12 @@ def validate_search_point(
         if with_fixed_value:
             search_point = self.config_at_search_point(search_point or {})
         precision = search_point.get("precision")
-        device = (
-            Device.CPU
-            if self.accelerator_spec.execution_provider
-            in AcceleratorLookup.get_execution_providers_for_device(Device.CPU)
-            else Device.GPU
-        )
-        if precision == ModelBuilder.Precision.FP16 and device == Device.CPU:
+
+        # Even when the device is GPU, pairing it with the CPU EP still rules out FP16.
+        if (precision == ModelBuilder.Precision.FP16) and not (
+            accelerator_spec.accelerator_type == Device.GPU
+            and accelerator_spec.execution_provider != "CPUExecutionProvider"
+        ):
             logger.info(
                 "FP16 is not supported on CPU. Valid precision + execution "
                 "provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA"
@@ -152,12 +151,12 @@ def _run_for_config(
             else Path(resolve_onnx_path(output_model_path, model.onnx_file_name))
         )
 
-        target_execution_provider = (
-            "cpu"
-            if self.accelerator_spec.execution_provider
-            in AcceleratorLookup.get_execution_providers_for_device(Device.CPU)
-            else "cuda"
-        )
+        if self.accelerator_spec.execution_provider == "DmlExecutionProvider":
+            target_execution_provider = "dml"
+        elif self.accelerator_spec.execution_provider == "CUDAExecutionProvider":
+            target_execution_provider = "cuda"
+        else:
+            target_execution_provider = "cpu"
 
         # Select cache location based on priority
         # HF_CACHE (HF >= v5) -> TRANSFORMERS_CACHE (HF < v5) -> local dir
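
The new if/elif chain replaces the device-lookup heuristic with an explicit EP-to-target mapping, adding "dml" alongside "cuda" and the "cpu" fallback. Behaviorally it is equivalent to a dict lookup; a sketch with illustrative names, not code from this diff:

```python
# Behavior-equivalent sketch of the EP-to-target branch above (illustrative names).
EP_TO_TARGET = {
    "DmlExecutionProvider": "dml",
    "CUDAExecutionProvider": "cuda",
}

def target_for(execution_provider: str) -> str:
    # Any other EP falls back to "cpu", matching the else branch.
    return EP_TO_TARGET.get(execution_provider, "cpu")

assert target_for("DmlExecutionProvider") == "dml"
assert target_for("CUDAExecutionProvider") == "cuda"
assert target_for("OpenVINOExecutionProvider") == "cpu"
```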
