For detailed instructions on using PyTorch with DirectML, see GPU accelerated ML training.
Follow the steps below to get set up with PyTorch on DirectML.
- Download and install Python 3.8 to 3.10.
- Clone this repo.
- Install torch-directml.
⚠️ Since torch-directml 0.1.13.1.*, torch and torchvision are installed automatically as dependencies.
```
pip install torch-directml
```
- Create a DML device and test it.
```python
import torch
import torch_directml

dml = torch_directml.device()
```
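To sanity-check the setup steps above before running any samples, you could use a small script like the one below. It is a minimal sketch that relies only on the standard library: it checks whether a given interpreter version falls in the 3.8 to 3.10 range listed above, and reports which of `torch`, `torchvision`, and `torch-directml` are installed (per the note above, installing torch-directml 0.1.13.1 or later should pull in the other two). The helper names `python_version_supported` and `installed_versions` are hypothetical and not part of torch-directml.

```python
import sys
from importlib import metadata

# Supported range taken from the setup steps: Python 3.8 to 3.10, inclusive.
MIN_PYTHON = (3, 8)
MAX_PYTHON = (3, 10)

# Distribution names as published on PyPI.
PACKAGES = ("torch", "torchvision", "torch-directml")


def python_version_supported(version=None):
    """Return True if (major, minor) lies within the supported range."""
    if version is None:
        version = sys.version_info[:2]
    return MIN_PYTHON <= tuple(version[:2]) <= MAX_PYTHON


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    status = "supported" if python_version_supported() else "NOT in 3.8-3.10"
    print("Python", sys.version.split()[0], status)
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

A missing entry for `torch` or `torchvision` after installing torch-directml would suggest the install used an older torch-directml release or a different environment.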
⚠️ Note that device creation changed in torch-directml 0.1.13 compared with previous versions. The torch-directml backend is currently mapped to "PrivateUse1". The new `torch_directml.device()` API is a convenient wrapper for creating your tensors on the correct device.
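One way to test the device is to create two small tensors on it, add them there, and copy the result back to the CPU. The sketch below assumes `torch_directml.device()` returns an ordinary `torch.device` (given the PrivateUse1 mapping, it should report the type `privateuseone`) and falls back to the CPU when `torch_directml` cannot be imported, so the same script also runs on machines without DirectML. The `get_device` and `smoke_test` helpers are hypothetical.

```python
import torch


def get_device():
    """Return the DirectML device if torch-directml is importable, else the CPU."""
    try:
        import torch_directml  # provided by `pip install torch-directml`
    except ImportError:
        return torch.device("cpu")
    return torch_directml.device()


def smoke_test(device):
    """Add two tensors on `device` and return the sum as a CPU tensor."""
    a = torch.tensor([1.0, 2.0]).to(device)
    b = torch.tensor([3.0, 4.0]).to(device)
    return (a + b).cpu()


if __name__ == "__main__":
    dml = get_device()
    print("device:", dml)  # assumption: a DirectML device shows type privateuseone
    print(smoke_test(dml))
```

Note that `dml` is a device object, not a string, so it is passed to `.to()` directly rather than as a name like `"cuda"`.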
The following sample models are included in this repo to help you get started. The samples include both inference and training scripts, and you can either train the models from scratch or use the supplied pre-trained weights.
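The sample training scripts are not reproduced here, but the general pattern for training on DirectML is ordinary PyTorch: move the model and each batch to the device, then run the usual forward, backward, and optimizer step. The sketch below illustrates that pattern with a hypothetical one-layer model fitting y = 2x + 1 on synthetic data; the model, data, and hyperparameters are illustrative assumptions, not taken from the samples, and the script falls back to the CPU when `torch_directml` is unavailable.

```python
import torch
from torch import nn


def get_device():
    """Return the DirectML device if torch-directml is importable, else the CPU."""
    try:
        import torch_directml
    except ImportError:
        return torch.device("cpu")
    return torch_directml.device()


def train_linear(device, steps=200, lr=0.1, seed=0):
    """Fit y = 2x + 1 with one linear layer on `device`; return (loss, model)."""
    torch.manual_seed(seed)
    x = torch.linspace(-1, 1, 64).unsqueeze(1)  # shape (64, 1)
    y = 2 * x + 1
    x, y = x.to(device), y.to(device)           # move the data to the device

    model = nn.Linear(1, 1).to(device)          # move the parameters to the device
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item(), model


if __name__ == "__main__":
    loss, model = train_linear(get_device())
    print(f"loss {loss:.2e}, weight {model.weight.item():.3f}, bias {model.bias.item():.3f}")
```

Inference follows the same rule: put the model in `eval()` mode and move both the model and its inputs to the device before calling it.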