diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 82424c80d7..0ae3ea9a90 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -9,6 +9,8 @@ This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent on the CartPole-v1 task from `Gymnasium `__. +You might find it helpful to read the original `Deep Q Learning (DQN) `__ paper + **Task** The agent has to decide between two actions - moving the cart left or @@ -83,7 +85,11 @@ plt.ion() # if GPU is to be used -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device( + "cuda" if torch.cuda.is_available() else + "mps" if torch.backends.mps.is_available() else + "cpu" +) ###################################################################### @@ -397,7 +403,7 @@ def optimize_model(): # can produce better results if convergence is not observed. # -if torch.cuda.is_available(): +if torch.cuda.is_available() or torch.backends.mps.is_available(): num_episodes = 600 else: num_episodes = 50