forked from datamllab/rlcard
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_rl.py
92 lines (72 loc) · 3.41 KB
/
run_rl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
''' An example of training a reinforcement learning agent on the environments in RLCard
'''
import os
import argparse
import torch
import rlcard
from rlcard.agents import RandomAgent
from rlcard.utils import get_device, set_seed, tournament, reorganize, Logger, plot_curve
def train(args):
    ''' Train a DQN or NFSP agent against random opponents, then plot and save it

    Args:
        args (argparse.Namespace): Run settings. The attributes read here are
            `algorithm`, `seed`, `env`, `log_dir`, `num_episodes`,
            `evaluate_every` and `num_eval_games`.

    Raises:
        ValueError: If `args.algorithm` is neither 'dqn' nor 'nfsp'.
    '''
    # Fail fast on an unknown algorithm. Without this check no branch below
    # binds `agent`, and the function would only fail later with an unrelated
    # NameError, after the environment has already been constructed.
    if args.algorithm not in ('dqn', 'nfsp'):
        raise ValueError(
            "Unsupported algorithm {!r}; expected 'dqn' or 'nfsp'".format(args.algorithm)
        )

    # Check whether gpu is available
    device = get_device()

    # Seed numpy, torch, random
    set_seed(args.seed)

    # Make the environment with seed
    env = rlcard.make(args.env, config={'seed': args.seed})

    # Initialize the agent and use random agents as opponents.
    # The agent classes are imported lazily so only the selected one is loaded.
    if args.algorithm == 'dqn':
        from rlcard.agents import DQNAgent
        agent = DQNAgent(num_actions=env.num_actions,
                         state_shape=env.state_shape[0],
                         mlp_layers=[64,64],
                         device=device)
    else:
        from rlcard.agents import NFSPAgent
        agent = NFSPAgent(num_actions=env.num_actions,
                          state_shape=env.state_shape[0],
                          hidden_layers_sizes=[64,64],
                          q_mlp_layers=[64,64],
                          device=device)

    # The learning agent takes seat 0; every remaining seat plays randomly
    agents = [agent] + [RandomAgent(num_actions=env.num_actions)
                        for _ in range(1, env.num_players)]
    env.set_agents(agents)

    # Start training
    with Logger(args.log_dir) as logger:
        for episode in range(args.num_episodes):

            if args.algorithm == 'nfsp':
                agent.sample_episode_policy()

            # Generate data from the environment
            trajectories, payoffs = env.run(is_training=True)

            # Reorganize the data to be state, action, reward, next_state, done
            trajectories = reorganize(trajectories, payoffs)

            # Feed transitions into agent memory, and train the agent
            # Here, we assume that the learning agent always plays the first
            # position and the other players play randomly (if any)
            for ts in trajectories[0]:
                agent.feed(ts)

            # Evaluate the performance. Play with random agents.
            if episode % args.evaluate_every == 0:
                logger.log_performance(env.timestep, tournament(env, args.num_eval_games)[0])

        # Get the paths
        csv_path, fig_path = logger.csv_path, logger.fig_path

    # Plot the learning curve.
    # NOTE(review): done after leaving the `with` block, presumably so the
    # logger has finished writing the CSV before it is read -- confirm
    # against Logger.
    plot_curve(csv_path, fig_path, args.algorithm)

    # Save model (the whole agent object is serialized, not just its weights)
    save_path = os.path.join(args.log_dir, 'model.pth')
    torch.save(agent, save_path)
    print('Model saved in', save_path)
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, pick the
    # visible CUDA devices, then hand everything to train().
    supported_envs = ['blackjack', 'leduc-holdem', 'limit-holdem', 'doudizhu',
                      'mahjong', 'no-limit-holdem', 'uno', 'gin-rummy']

    # (flag, keyword arguments) pairs, registered in this order
    option_table = (
        ('--env', dict(type=str, default='leduc-holdem', choices=supported_envs)),
        ('--algorithm', dict(type=str, default='dqn', choices=['dqn', 'nfsp'])),
        ('--cuda', dict(type=str, default='')),
        ('--seed', dict(type=int, default=42)),
        ('--num_episodes', dict(type=int, default=5000)),
        ('--num_eval_games', dict(type=int, default=2000)),
        ('--evaluate_every', dict(type=int, default=100)),
        ('--log_dir', dict(type=str, default='experiments/leduc_holdem_dqn_result/')),
    )

    cli = argparse.ArgumentParser("DQN/NFSP example in RLCard")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()

    # Expose only the devices named by --cuda before training starts.
    # NOTE(review): the default '' presumably hides every GPU -- confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    train(args)