
PyTorch with DirectML Samples

For detailed instructions on getting started with PyTorch with DirectML, see GPU accelerated ML training.

Setup

Follow the steps below to get set up with PyTorch on DirectML.

  1. Download and install Python 3.8.

  2. Clone this repo.

  3. Install the prerequisites:

    pip install torchvision==0.9.0
    pip uninstall torch
    pip install pytorch-directml==1.8.0a0.dev220506

Note: Installing torchvision automatically pulls in torch==1.8.0 as a dependency. That package is not needed and collides with the pytorch-directml package, so the torch package must be uninstalled after the prerequisites are installed, as in the commands above.

  4. (optional) Run pip list. The output should include the following packages:

    pytorch-directml        1.8.0a0.dev220506
    torchvision             0.9.0
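
The optional check in step 4 can also be scripted. The sketch below is a hypothetical helper, not part of this repo. It uses only the standard library's importlib.metadata to compare the installed distributions against the versions listed above, and it flags the stock torch wheel that torchvision pulls in. The expected versions are copied from the setup steps. The function names normalize, installed_versions, and find_problems are illustrative assumptions.

```python
"""Hypothetical helper: sanity-check a PyTorch with DirectML environment."""
import re
import sys
from importlib import metadata
from typing import Dict, List, Optional, Tuple

# Versions copied from the setup steps; adjust if you install different builds.
EXPECTED_PYTHON = (3, 8)
EXPECTED = {
    "pytorch-directml": "1.8.0a0.dev220506",
    "torchvision": "0.9.0",
}
# The stock torch wheel that torchvision pulls in collides with pytorch-directml.
CONFLICTING = "torch"


def normalize(name: str) -> str:
    """Normalize a distribution name (PEP 503) so 'PyTorch_DirectML' matches."""
    return re.sub(r"[-_.]+", "-", name).lower()


def installed_versions() -> Dict[str, str]:
    """Map each installed distribution's normalized name to its version."""
    versions: Dict[str, str] = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:
            versions[normalize(name)] = dist.version
    return versions


def find_problems(installed: Dict[str, str],
                  python_version: Tuple[int, int]) -> List[str]:
    """Return every deviation from the expected setup; empty means OK."""
    problems: List[str] = []
    if tuple(python_version) != EXPECTED_PYTHON:
        problems.append("expected Python %d.%d, found %d.%d"
                        % (EXPECTED_PYTHON + tuple(python_version)))
    for name, want in EXPECTED.items():
        have: Optional[str] = installed.get(name)
        if have != want:
            problems.append("%s: expected %s, found %s" % (name, want, have))
    stray = installed.get(CONFLICTING)
    if stray is not None:
        problems.append("stock torch %s is installed; run 'pip uninstall torch'"
                        % stray)
    return problems


if __name__ == "__main__":
    for issue in find_problems(installed_versions(), sys.version_info[:2]):
        print("problem:", issue)
```

Run it with the same interpreter you installed the packages into (for example, python check_env.py, where the file name is only an example). No output means no deviations were found. find_problems takes the installed versions as a plain dictionary, so the comparison logic can be checked without pytorch-directml installed.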

Samples

The following sample models are included in this repo to help you get started. The samples include both inference and training scripts, and you can either train the models from scratch or use the supplied pre-trained weights.

External Links