For detailed instructions on getting started with PyTorch with DirectML, see GPU accelerated ML training.
Follow the steps below to get set up with PyTorch on DirectML.
- Download and install Python 3.8.
- Clone this repo.
- Install prerequisites:
pip install torchvision==0.9.0
pip uninstall torch
pip install pytorch-directml==1.8.0a0.dev220506
Note: The torchvision package automatically installs torch==1.8.0 as a dependency, but it is not needed and conflicts with the pytorch-directml package. Uninstall the torch package after installing the prerequisites, as shown in the commands above.
- (optional) Run pip list. The following packages should be installed:
pytorch-directml 1.8.0a0.dev220506
torchvision 0.9.0
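The optional check above can also be scripted. The following is a minimal sketch, not part of this repo: it uses the standard-library importlib.metadata module (available from Python 3.8) to compare installed versions against the ones listed above and to flag a leftover torch package. The names EXPECTED, installed_version, and check_environment are hypothetical and chosen only for illustration.

```python
from importlib import metadata

# Versions taken from the package list above.
EXPECTED = {
    "pytorch-directml": "1.8.0a0.dev220506",
    "torchvision": "0.9.0",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_environment(lookup=installed_version):
    """Return a list of human-readable problems; an empty list means OK.

    `lookup` is injectable (hypothetical design choice) so the check can be
    exercised without the packages actually being installed.
    """
    problems = []
    for name, wanted in EXPECTED.items():
        found = lookup(name)
        if found != wanted:
            problems.append(f"{name}: expected {wanted}, found {found}")
    # The stock torch distribution collides with pytorch-directml.
    if lookup("torch") is not None:
        problems.append("torch is still installed; run 'pip uninstall torch'")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("Environment OK" if not issues else "\n".join(issues))
```

Running the script prints either a confirmation or one line per mismatch, which may be easier to scan than the full pip list output.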
The following sample models are included in this repo to help you get started. The repo includes both inference and training scripts, and you can either train the models from scratch or use the supplied pre-trained weights.