ddqn.py (forked from blackredscarf/pytorch-DQN)

import argparse
import random
import gym
import torch
from torch.optim import Adam
from tester import Tester
from buffer import ReplayBuffer
from config import Config
from core.util import get_class_attr_val
from model import DQN
from trainer import Trainer
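

# DDQNAgent implements Double DQN: a separate target network is kept in sync
# with the online network, and the TD target uses the online network to select
# the next action while the target network evaluates it.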
class DDQNAgent:
    def __init__(self, config: Config):
        self.config = config
        self.is_training = True
        self.buffer = ReplayBuffer(self.config.max_buff)

        # Online and target networks; the target starts as a copy of the online net.
        self.model = DQN(self.config.state_dim, self.config.action_dim)
        self.target_model = DQN(self.config.state_dim, self.config.action_dim)
        self.target_model.load_state_dict(self.model.state_dict())
        self.model_optim = Adam(self.model.parameters(), lr=self.config.learning_rate)

        # Move to GPU only when explicitly requested via config.use_cuda.
        if self.config.use_cuda:
            self.cuda()

    def act(self, state, epsilon=None):
        if epsilon is None:
            epsilon = self.config.epsilon_min

        # Epsilon-greedy action selection; always act greedily at evaluation time.
        if random.random() > epsilon or not self.is_training:
            state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            if self.config.use_cuda:
                state = state.cuda()
            q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.config.action_dim)
        return action

    def learning(self, fr):
        s0, a, r, s1, done = self.buffer.sample(self.config.batch_size)

        s0 = torch.tensor(s0, dtype=torch.float)
        s1 = torch.tensor(s1, dtype=torch.float)
        a = torch.tensor(a, dtype=torch.long)
        r = torch.tensor(r, dtype=torch.float)
        done = torch.tensor(done, dtype=torch.float)

        if self.config.use_cuda:
            s0 = s0.cuda()
            s1 = s1.cuda()
            a = a.cuda()
            r = r.cuda()
            done = done.cuda()

        q_values = self.model(s0)
        next_q_values = self.model(s1)
        next_q_state_values = self.target_model(s1)

        q_value = q_values.gather(1, a.unsqueeze(1)).squeeze(1)
        # Double DQN: the online network selects the next action,
        # while the target network evaluates it.
        next_q_value = next_q_state_values.gather(1, next_q_values.max(1)[1].unsqueeze(1)).squeeze(1)
        expected_q_value = r + self.config.gamma * next_q_value * (1 - done)

        # Detach the target so gradients flow only through q_value.
        loss = (q_value - expected_q_value.detach()).pow(2).mean()

        self.model_optim.zero_grad()
        loss.backward()
        self.model_optim.step()

        # Periodically sync the target network with the online network.
        if fr % self.config.update_tar_interval == 0:
            self.target_model.load_state_dict(self.model.state_dict())

        return loss.item()

    def cuda(self):
        self.model.cuda()
        self.target_model.cuda()

    def load_weights(self, model_path):
        if model_path is None:
            return
        self.model.load_state_dict(torch.load(model_path))

    def save_model(self, output, tag=''):
        torch.save(self.model.state_dict(), '%s/model_%s.pkl' % (output, tag))

    def save_config(self, output):
        with open(output + '/config.txt', 'w') as f:
            attr_val = get_class_attr_val(self.config)
            for k, v in attr_val.items():
                f.write(str(k) + " = " + str(v) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--train', dest='train', action='store_true', help='train model')
    parser.add_argument('--env', default='CartPole-v0', type=str, help='gym environment')
    parser.add_argument('--test', dest='test', action='store_true', help='test model')
    parser.add_argument('--model_path', type=str, help='if test, import the model')
    args = parser.parse_args()

    # Example: python ddqn.py --train --env CartPole-v0
    config = Config()
    config.env = args.env
    config.gamma = 0.99
    config.epsilon = 1
    config.epsilon_min = 0.01
    config.eps_decay = 500
    config.frames = 160000
    config.use_cuda = True
    config.learning_rate = 1e-3
    config.max_buff = 1000
    config.update_tar_interval = 100
    config.batch_size = 128
    config.print_interval = 200
    config.log_interval = 200
    config.win_reward = 198  # CartPole-v0
    config.win_break = True

    env = gym.make(config.env)
    config.action_dim = env.action_space.n
    config.state_dim = env.observation_space.shape[0]
    agent = DDQNAgent(config)

    if args.train:
        trainer = Trainer(agent, env, config)
        trainer.train()

    elif args.test:
        if args.model_path is None:
            print('please provide the model path: --model_path xxxx')
            exit(0)
        tester = Tester(agent, env, args.model_path)
        tester.test()
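
# Illustrative usage, based on the argparse flags above (the model path is a placeholder):
#   python ddqn.py --train --env CartPole-v0
#   python ddqn.py --test --env CartPole-v0 --model_path <path/to/model.pkl>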