# leduc_holdem_cfr.py -- forked from datamllab/rlcard
''' An example of solving Leduc Hold'em with CFR
'''
import numpy as np
import rlcard
from rlcard.agents.cfr_agent import CFRAgent
from rlcard import models
from rlcard.utils.utils import set_global_seed
from rlcard.utils.logger import Logger
# --- Environments -----------------------------------------------------------
# Training environment: built with allow_step_back=True and handed to the CFR
# agent below (presumably CFR's traversal needs to undo moves -- confirm
# against CFRAgent).
env = rlcard.make('leduc-holdem', allow_step_back=True)
# Second, independent environment that is only used to play evaluation games.
eval_env = rlcard.make('leduc-holdem')

# --- Schedule ---------------------------------------------------------------
episode_num = 10000000   # total number of training iterations
evaluate_every = 100     # evaluate (and save the model) every N iterations
evaluate_num = 10000     # number of games played per evaluation
save_plot_every = 1000   # write a learning-curve figure every N iterations

# --- Output locations -------------------------------------------------------
root_path = './experiments/leduc_holdem_cfr_result/'
figure_path = root_path + 'figures/'
csv_path = root_path + 'performance.csv'
log_path = root_path + 'log.txt'

# --- Reproducibility --------------------------------------------------------
set_global_seed(0)

# --- Agents -----------------------------------------------------------------
# Create the CFR agent on the training environment, then call load() so that
# training resumes from a previously saved model if one is available.
# NOTE(review): behaviour of load() when nothing was saved is not visible
# here -- confirm it is a no-op in that case.
agent = CFRAgent(env)
agent.load()

# Evaluation games seat the CFR agent first and the first agent of the
# pre-trained 'leduc-holdem-nfsp' model second.
nfsp_opponent = models.load('leduc-holdem-nfsp').agents[0]
eval_env.set_agents([agent, nfsp_opponent])

# --- Logging ----------------------------------------------------------------
# Logger collects the evaluation points and draws the learning curve.
logger = Logger(
    xlabel='iteration',
    ylabel='reward',
    legend='CFR on Leduc Holdem',
    log_path=log_path,
    csv_path=csv_path,
)
# Main training loop: one agent.train() call per iteration, with a periodic
# checkpoint + evaluation against the opponent seated in eval_env, and a
# periodic learning-curve figure.
for episode in range(episode_num):
    agent.train()
    print('\rIteration {}'.format(episode), end='')

    # Evaluate the performance. Play with NFSP agents.
    if episode % evaluate_every == 0:
        agent.save()  # Save model

        # Accumulate the payoff of seat 0 over all evaluation games; seat 0
        # is where the CFR agent was placed by eval_env.set_agents(...).
        reward = 0
        for _game in range(evaluate_num):
            _, payoffs = eval_env.run(is_training=False)
            reward += payoffs[0]

        # Compute the mean once and reuse it for both the log line and the
        # plotted point so the two can never disagree.
        average_reward = float(reward) / evaluate_num

        logger.log('\n########## Evaluation ##########')
        logger.log('Iteration: {} Average reward is {}'.format(episode, average_reward))

        # Add point to logger.
        # NOTE(review): the x value is env.timestep while the Logger's xlabel
        # is 'iteration' -- confirm which quantity the curve should use.
        logger.add_point(x=env.timestep, y=average_reward)

    # Make plot (skipped at iteration 0, where there is nothing to show yet).
    if episode % save_plot_every == 0 and episode > 0:
        logger.make_plot(save_path=figure_path + str(episode) + '.png')

# Make the final plot. `episode` only exists if the loop ran at least once,
# so guard against episode_num == 0 instead of failing with a NameError.
if episode_num > 0:
    logger.make_plot(save_path=figure_path + 'final_' + str(episode) + '.png')