For detailed instructions on getting started with PyTorch with DirectML, see GPU accelerated ML training.
Follow the steps below to get set up with PyTorch on DirectML.
- Download and install Python 3.8 to 3.10.
- Clone this repo.
- Install prerequisites:
  ```shell
  pip install torchvision==0.14.0
  pip install torch==1.13
  pip install torch-directml
  ```
- (optional) Run `pip list`. The following packages should be installed:
  ```text
  torch           1.13.0
  torch-directml  0.1.13.*
  torchvision     0.14.0
  ```
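- Create a DML device and test it:
  ```python
  import torch
  import torch_directml

  dml = torch_directml.device()  # DirectML device object (pass this variable, not a string)
  x = torch.tensor([1.0, 2.0]).to(dml)  # quick test: allocate a tensor on the DML device
  print((x + x).cpu())  # run an add on DML, then copy the result to the CPU to print it
  ```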
⚠️ Note that device creation has changed in torch-directml 1.13 from previous versions. The torch-directml backend is currently mapped to "PrivateUse1". The new `torch_directml.device()` API is a convenient wrapper for creating your tensors on the correct device.
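If you would rather check the setup from a script than by reading `pip list`, a preflight check along these lines may help. This is a minimal sketch and not part of torch-directml: the names `EXPECTED`, `python_supported`, `version_mismatches`, and `pick_device` are hypothetical, the version pins are copied from the steps above, and falling back to the CPU when `torch_directml` is not installed is an assumption of this sketch, not library behavior.

```python
"""Hypothetical preflight check for the PyTorch-DirectML setup (not part of torch-directml).

Assumes the version pins listed in this README; adjust EXPECTED for other releases.
"""
import fnmatch
import importlib.metadata
import importlib.util
import sys

# Pins copied from the install and `pip list` steps; "*" matches any trailing suffix.
EXPECTED = {
    "torch": "1.13.0",
    "torch-directml": "0.1.13.*",
    "torchvision": "0.14.0",
}


def python_supported(version=sys.version_info) -> bool:
    """Return True if the interpreter is Python 3.8 to 3.10 (inclusive)."""
    return (3, 8) <= tuple(version[:2]) <= (3, 10)


def version_mismatches(expected=EXPECTED):
    """Return (package, wanted, installed) tuples for packages that are missing or differ.

    `installed` is None when the package is not installed at all.
    """
    problems = []
    for name, pattern in expected.items():
        try:
            installed = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            installed = None
        if installed is None or not fnmatch.fnmatch(installed, pattern):
            problems.append((name, pattern, installed))
    return problems


def pick_device():
    """Return the DirectML device if torch_directml is installed, else the string "cpu".

    The CPU fallback is an assumption of this sketch, not torch-directml behavior.
    """
    if importlib.util.find_spec("torch_directml") is None:
        return "cpu"
    import torch_directml  # imported lazily so the other checks run without it

    return torch_directml.device()


if __name__ == "__main__":
    print("Python 3.8-3.10:", python_supported())
    for name, wanted, installed in version_mismatches():
        print(f"{name}: expected {wanted}, found {installed}")
    print("Device:", pick_device())
```

Because the pins are matched exactly, a PyTorch build that reports a local version suffix (for example `1.13.0+cpu`) would be listed as a mismatch; loosen the pattern in `EXPECTED` if that applies to your install.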
The following sample models are included in this repo to help you get started. The samples include both inference and training scripts, and you can either train the models from scratch or use the supplied pre-trained weights.