Add AMD GPU support #1546

Merged: 11 commits, merged on Nov 24, 2023
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -30,6 +30,8 @@
title: How to accelerate training
- local: onnxruntime/usage_guides/gpu
title: Accelerated inference on NVIDIA GPUs
- local: onnxruntime/usage_guides/amdgpu
title: Accelerated inference on AMD GPUs
title: How-to guides
isExpanded: false
- sections:
124 changes: 124 additions & 0 deletions docs/source/onnxruntime/usage_guides/amdgpu.mdx
@@ -0,0 +1,124 @@
# Accelerated inference on AMD GPUs supported by ROCm

By default, ONNX Runtime runs inference on CPU devices. However, it is possible to place supported operations on an AMD Instinct GPU while leaving any unsupported ones on CPU. In most cases, this allows costly operations to be placed on the GPU, significantly accelerating inference.
Contributor:

We could clarify that we have tested on Instinct GPUs, but that support matrix is https://rocm.docs.amd.com/en/latest/release/gpu_os_support.html (unless ROCMExecutionProvider explicitly requires Instinct? In which case we can give a ref)

Contributor Author:

Done


Our testing primarily involved AMD Instinct GPUs; for specific GPU compatibility, refer to the [AMD ROCm GPU OS support matrix](https://rocm.docs.amd.com/en/latest/release/gpu_os_support.html).

This guide will show you how to run inference with the `ROCMExecutionProvider`, the execution provider that ONNX Runtime supports for AMD GPUs.

## Installation
The following setup installs ONNX Runtime with the ROCm Execution Provider, using ROCm 5.7.

#### 1. ROCm Installation

To install ROCm 5.7, please follow the [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html).

#### 2. PyTorch Installation with ROCm Support
The Optimum ONNX Runtime integration relies on some functionalities of Transformers that require PyTorch. For now, we recommend using PyTorch compiled against ROCm 5.7, which can be installed following the [PyTorch installation guide](https://pytorch.org/get-started/locally/):

```bash
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
```

<Tip>
For docker installation, the following base image is recommended: `rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1`
</Tip>
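
For example, a container can be started from this image as follows. This is an illustrative command, not part of the original guide: the `--device` and `--group-add` flags are the usual ROCm container settings and may need adjusting for your system.

```bash
# Expose the ROCm kernel driver and GPU render nodes to the container
docker run -it \
  --device=/dev/kfd --device=/dev/dri --group-add video \
  rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
```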

#### 3. ONNX Runtime installation with ROCm Execution Provider

```bash
# pre-requisites
pip install -U pip
pip install cmake onnx
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

# Install ONNXRuntime from source
git clone --recursive https://github.com/ROCmSoftwarePlatform/onnxruntime.git
cd onnxruntime
git checkout rocm5.7_internal_testing_eigen-3.4.zip_hash

./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm
pip install build/Linux/Release/dist/*
```

<Tip>
To avoid conflicts between `onnxruntime` and `onnxruntime-rocm`, make sure the package `onnxruntime` is not installed by running `pip uninstall onnxruntime` prior to installing `onnxruntime-rocm`.
</Tip>
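
For example (a minimal sketch; the wheel path assumes the source build from the step above):

```bash
# Remove any CPU-only onnxruntime first to avoid conflicting packages
pip uninstall -y onnxruntime
# Then (re-)install the ROCm-enabled wheel built above
pip install build/Linux/Release/dist/*
```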

### Checking the ROCm installation is successful

Before going further, run the following sample code to check whether the install was successful:

```python
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> from transformers import AutoTokenizer

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "philschmid/tiny-bert-sst2-distilled",
... export=True,
... provider="ROCMExecutionProvider",
... )

>>> tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
>>> inputs = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt", padding=True)

>>> outputs = ort_model(**inputs)
>>> assert ort_model.providers == ["ROCMExecutionProvider", "CPUExecutionProvider"]
```

If this code runs without issue, congratulations, the installation is successful! If you encounter the following error, or a similar one,

```
ValueError: Asked to use ROCMExecutionProvider as an ONNX Runtime execution provider, but the available execution providers are ['CPUExecutionProvider'].
```

then something is wrong with the ROCm or ONNX Runtime installation.
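
To help diagnose this, you can list the execution providers that your ONNX Runtime build actually exposes (a minimal check that does not involve Optimum):

```python
>>> import onnxruntime as ort

>>> print(ort.get_available_providers())  # doctest: +IGNORE_RESULT
# "ROCMExecutionProvider" should appear in this list on a correct ROCm build
```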

### Use the ROCm Execution Provider with ORT models

For ORT models, usage is straightforward: simply specify the `provider` argument in the `ORTModel.from_pretrained()` method. Here's an example:

```python
>>> from optimum.onnxruntime import ORTModelForSequenceClassification

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "distilbert-base-uncased-finetuned-sst-2-english",
... export=True,
... provider="ROCMExecutionProvider",
... )
```

The model can then be used with the common 🤗 Transformers API for inference and evaluation, such as [pipelines](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/pipelines).
When using a Transformers pipeline, note that the `device` argument should be set so that pre- and post-processing run on the GPU, following the example below:

```python
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

>>> pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
>>> result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
>>> print(result) # doctest: +IGNORE_RESULT
# printing: [{'label': 'POSITIVE', 'score': 0.9997727274894714}]
```

Additionally, you can pass the session option `log_severity_level = 0` (verbose) to check whether all nodes are indeed placed on the ROCm execution provider:

```python
>>> import onnxruntime

>>> session_options = onnxruntime.SessionOptions()
>>> session_options.log_severity_level = 0

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "distilbert-base-uncased-finetuned-sst-2-english",
... export=True,
... provider="ROCMExecutionProvider",
... session_options=session_options
... )
```
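
After loading, you can also confirm which providers were effectively registered on the model, using the same `providers` attribute as in the installation check above:

```python
>>> print(ort_model.providers)  # doctest: +IGNORE_RESULT
# e.g. ['ROCMExecutionProvider', 'CPUExecutionProvider']
```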

### Observed time gains

Coming soon!
28 changes: 28 additions & 0 deletions docs/source/onnxruntime/usage_guides/trainer.mdx
@@ -56,6 +56,8 @@ To use `ORTTrainer` or `ORTSeq2SeqTrainer`, you need to install ONNX Runtime Trai
To set up the environment, we __strongly recommend__ you install the dependencies with Docker to ensure that the versions are correct and well
configured. You can find dockerfiles with various combinations [here](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training/docker).

#### Setup for NVIDIA GPU

Below we take the installation of `onnxruntime-training 1.14.0` as an example:

* If you want to install `onnxruntime-training 1.14.0` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort1.14.0-cu116):
@@ -80,6 +82,32 @@ And run post-installation configuration:
python -m torch_ort.configure
```

#### Setup for AMD GPU

Below we take the installation of the `onnxruntime-training` nightly as an example:

* If you want to install `onnxruntime-training` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort-nightly-rocm57):

```bash
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
```

* If you want to install the dependencies in a local Python environment instead, you can pip install them once you have [ROCm 5.7](https://rocmdocs.amd.com/en/latest/deploy/linux/quick_start.html) properly installed.

```bash
pip install onnx ninja
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2
```

And run post-installation configuration:

```bash
python -m torch_ort.configure
```
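
As a quick sanity check (a minimal sketch, assuming the steps above completed without errors), verify that the ROCm execution provider is visible to ONNX Runtime:

```python
import onnxruntime as ort

# The ROCm nightly of onnxruntime-training should expose the ROCm execution provider
print(ort.get_available_providers())  # expect "ROCMExecutionProvider" in the list
```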

### Install Optimum

You can install Optimum via pypi:
43 changes: 43 additions & 0 deletions examples/onnxruntime/training/docker/Dockerfile-ort-nightly-rocm57
@@ -0,0 +1,43 @@
# Use rocm image
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
CMD rocm-smi

# Ignore interactive questions during `docker build`
ENV DEBIAN_FRONTEND noninteractive

# Versions
# available options 3.10
ARG PYTHON_VERSION=3.10

# Bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]

# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN apt-get autoremove -y

ARG PYTHON_EXE=/opt/conda/envs/py_$PYTHON_VERSION/bin/python

# (Optional) Install test dependencies
RUN $PYTHON_EXE -m pip install -U pip
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece --no-cache-dir
RUN $PYTHON_EXE -m pip install deepspeed --no-cache-dir
RUN conda install -y mpi4py

# PyTorch
RUN $PYTHON_EXE -m pip install onnx ninja

# ORT Module
RUN $PYTHON_EXE -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
RUN $PYTHON_EXE -m pip install torch-ort
RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
RUN $PYTHON_EXE -m torch_ort.configure

WORKDIR .

CMD ["/bin/bash"]
12 changes: 8 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
@@ -308,16 +308,20 @@ def to(self, device: Union[torch.device, str, int]):
if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

if device.type == "cuda" and self._use_io_binding is False:
self.device = device
provider = get_provider_for_device(self.device)
validate_provider_availability(provider) # raise error if the provider is not available

# IOBinding is only supported for CPU and CUDA Execution Providers.
if device.type == "cuda" and self._use_io_binding is False and provider == "CUDAExecutionProvider":
self.use_io_binding = True
logger.info(
"use_io_binding was set to False, setting it to True because it can provide a huge speedup on GPUs. "
"It is possible to disable this feature manually by setting the use_io_binding attribute back to False."
)

self.device = device
provider = get_provider_for_device(self.device)
validate_provider_availability(provider) # raise error if the provider is not available
if provider == "ROCMExecutionProvider":
self.use_io_binding = False

self.model.set_providers([provider], provider_options=[provider_options])
self.providers = self.model.get_providers()
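
As a hedged illustration of the behavior this change introduces (not part of the diff itself): on a machine where ONNX Runtime exposes `ROCMExecutionProvider` but not `CUDAExecutionProvider`, moving an ORT model to a "cuda" device maps it to the ROCm provider and keeps IOBinding disabled.

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

# Illustrative sketch only, assuming a ROCm-only build of ONNX Runtime as in the docs above
ort_model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled",
    export=True,
)
ort_model = ort_model.to("cuda:0")
print(ort_model.providers)  # expected: ['ROCMExecutionProvider', 'CPUExecutionProvider']
```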
13 changes: 10 additions & 3 deletions optimum/onnxruntime/utils.py
@@ -63,7 +63,9 @@ def _is_gpu_available():
Checks if a gpu is available.
"""
available_providers = ort.get_available_providers()
if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available():
if (
"CUDAExecutionProvider" in available_providers or "ROCMExecutionProvider" in available_providers
) and torch.cuda.is_available():
return True
else:
return False
@@ -184,7 +186,7 @@ def get_device_for_provider(provider: str, provider_options: Dict) -> torch.devi
"""
Gets the PyTorch device (CPU/CUDA) associated with an ONNX Runtime provider.
"""
if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider", "ROCMExecutionProvider"]:
return torch.device(f"cuda:{provider_options['device_id']}")
else:
return torch.device("cpu")
@@ -194,7 +196,12 @@ def get_provider_for_device(device: torch.device) -> str:
"""
Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
"""
return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"
if device.type.lower() == "cuda":
if "CUDAExecutionProvider" in ort.get_available_providers():
return "CUDAExecutionProvider"
else:
return "ROCMExecutionProvider"
return "CPUExecutionProvider"


def parse_device(device: Union[torch.device, str, int]) -> Tuple[torch.device, Dict]:
11 changes: 11 additions & 0 deletions optimum/utils/testing_utils.py
@@ -69,6 +69,17 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


def require_ort_rocm(test_case):
"""Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime."""
import onnxruntime as ort

providers = ort.get_available_providers()

return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")(
test_case
)


def require_hf_token(test_case):
"""
Decorator marking a test that requires huggingface hub token.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -33,6 +33,9 @@ known-first-party = ["optimum"]
[tool.pytest.ini_options]
markers = [
"gpu_test",
"cuda_ep_test",
"trt_ep_test",
"rocm_ep_test",
"tensorflow_test",
"timm_test",
"run_in_series",
2 changes: 1 addition & 1 deletion tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu
@@ -23,4 +23,4 @@ COPY . /workspace/optimum
RUN pip install /workspace/optimum[onnxruntime-gpu,tests]

ENV TEST_LEVEL=1
CMD pytest onnxruntime/test_*.py --durations=0 -s -vvvvv -m gpu_test
CMD pytest onnxruntime/test_*.py --durations=0 -s -vvvvv -m cuda_ep_test -m trt_ep_test