Skip to content

Commit

Permalink
Improve DQN Tutorial (#2934)
Browse files Browse the repository at this point in the history
Co-authored-by: Svetlana Karslioglu <[email protected]>
  • Loading branch information
alperenunlu and svekars committed Jun 18, 2024
1 parent be898cb commit 0740801
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions intermediate_source/reinforcement_q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper
**Task**
The agent has to decide between two actions - moving the cart left or
Expand Down Expand Up @@ -83,7 +85,11 @@
plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else
"cpu"
)


######################################################################
Expand Down Expand Up @@ -397,7 +403,7 @@ def optimize_model():
# can produce better results if convergence is not observed.
#

if torch.cuda.is_available():
if torch.cuda.is_available() or torch.backends.mps.is_available():
num_episodes = 600
else:
num_episodes = 50
Expand Down

0 comments on commit 0740801

Please sign in to comment.