Skip to content

Commit 083694e

Browse files
committed
Merge branch main
2 parents 05f2d8e + 6f6e844 commit 083694e

File tree

8 files changed

+406
-12
lines changed

8 files changed

+406
-12
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
**/__pycache__/
22
.DS_Store
3+
ghost.pth
4+
pacman.pth

.gitmodules

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "core"]
22
path = core
3-
url = [email protected]:PacMan-Logic/Logic-core.git
3+
url = [email protected]:PacMan-Logic/Logic-core.git

ai.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ def choose_move(self, game_state: GameState):
277277

278278
return [Direction.STAY.value] # 默认停留
279279

280-
281280
# TODO: 你需要实现一个ai函数
282281

283-
ai_func = PacmanAI().choose_move # TODO: 你需要把ai_func替换为自己的ai函数
282+
ai_func = PacmanAI().choose_move
284283
__all__ = ["ai_func"]

ai_rl.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from core.gamedata import GameState
2+
from core.GymEnvironment import PacmanEnv
3+
from model import *
4+
from train import state_dict_to_tensor
5+
6+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7+
8+
9+
class PacmanAI:
10+
def __init__(self, device=device):
11+
self.device = device
12+
self.pacman_net = PacmanNet(4, 5, 40)
13+
self.pacman_net.load_state_dict(torch.load("pacman.pth"))
14+
self.pacman_net.to(self.device)
15+
self.pacman_net.eval()
16+
17+
def __call__(self, game_state: GameState):
18+
state = game_state.gamestate_to_statedict()
19+
state_tensor, extra = state_dict_to_tensor(state)
20+
with torch.no_grad():
21+
op = (
22+
self.pacman_net(state_tensor.to(self.device), extra.to(self.device))
23+
.argmax(1)
24+
.cpu()
25+
)
26+
return [op.item()]
27+
28+
29+
if __name__ == "__main__":
30+
ai = PacmanAI()
31+
env = PacmanEnv()
32+
env.reset()
33+
state = env.game_state()
34+
35+
out = ai(state)
36+
print(out)

core

main.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import sys
21
import json
2+
import sys
33

4-
from ai_to_judger import pacman_to_judger
5-
from ai_to_judger import ghost_to_judger
64
from ai import *
7-
from core.GymEnvironment import PacmanEnv
5+
from ai_rl import PacmanAI
6+
from ai_to_judger import ghost_to_judger, pacman_to_judger
87
from core.gamedata import GameState
8+
from core.GymEnvironment import PacmanEnv
99
from utils.utils import write_to_judger
1010

1111

@@ -27,6 +27,7 @@ def __init__(self):
2727
id = int(input())
2828
self.id = id
2929
self.level_change = True
30+
self.eat_all_beans = False
3031

3132
def run(self, ai):
3233
while 1:
@@ -40,7 +41,7 @@ def run(self, ai):
4041
# 0号玩家发送信息
4142
pacman_op(self.env, ai)
4243

43-
# 1号玩家发送信息
44+
# 等待1号玩家发送信息
4445
get_info = input()
4546
print(f"receive info: {get_info}", file=sys.stderr)
4647

@@ -50,13 +51,13 @@ def run(self, ai):
5051
get_op_json = json.loads(get_op)
5152
pacman_action = get_op_json["pacman_action"]
5253
ghosts_action = get_op_json["ghosts_action"]
53-
board, score, self.level_change = self.env.step(
54+
info , pacman_reward , ghosts_reward , self.level_change , self.eat_all_beans = self.env.step(
5455
pacman_action, ghosts_action
5556
)
5657
else:
5758
# 当前为1号玩家
5859

59-
# 0号玩家发送信息
60+
# 等待0号玩家发送信息
6061
get_info = input()
6162
print(f"receive info: {get_info}", file=sys.stderr)
6263

@@ -69,12 +70,13 @@ def run(self, ai):
6970
get_op_json = json.loads(get_op)
7071
pacman_action = get_op_json["pacman_action"]
7172
ghosts_action = get_op_json["ghosts_action"]
72-
board, score, self.level_change = self.env.step(
73+
info , pacman_reward , ghosts_reward , self.level_change , self.eat_all_beans = self.env.step(
7374
pacman_action, ghosts_action
7475
)
7576

7677

7778
if __name__ == "__main__":
7879
print("init done", file=sys.stderr)
7980
controller = Controller()
81+
# ai_func = PacmanAI()
8082
controller.run(ai_func)

model.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
6+
7+
# state-value network for pacman
8+
class PacmanNet(nn.Module):
9+
def __init__(self, input_channel_num, num_actions, extra_size):
10+
super().__init__()
11+
self.channels = input_channel_num
12+
self.embeddings = nn.ModuleList(
13+
[nn.Embedding(9, 16) for _ in range(input_channel_num)])
14+
self.conv1 = nn.Conv2d(64, 64, kernel_size=8, stride=4)
15+
self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2)
16+
self.bn = nn.BatchNorm2d(64)
17+
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=2)
18+
19+
self.encoder = nn.Linear(extra_size, 64)
20+
21+
self.fc1 = nn.Linear(64, 32)
22+
self.fc2 = nn.Linear(32, num_actions)
23+
24+
def forward(self, x, y):
25+
B, C, H, W = x.shape
26+
embedded_channels = []
27+
for i in range(self.channels):
28+
flattened_channel = x[:, i, :, :].view(B, -1).long()
29+
embedded_channel = self.embeddings[i](flattened_channel)
30+
embedded_channel = embedded_channel.view(
31+
B, 16, H, W)
32+
embedded_channels.append(embedded_channel)
33+
# Concatenate along the channel dimension
34+
x = torch.cat(embedded_channels, dim=1)
35+
36+
x = F.relu(self.conv1(x))
37+
x = F.relu(self.conv2(x))
38+
x = self.bn(x)
39+
x = F.relu(self.conv3(x))
40+
y = F.sigmoid(self.encoder(y))
41+
# print(x.shape)
42+
x = x.view(x.size(0), -1)
43+
x = F.relu(self.fc1(x+y))
44+
return self.fc2(x)
45+
46+
47+
# state-value network for ghost
48+
class GhostNet(nn.Module):
49+
def __init__(self, input_channel_num, num_actions, extra_size):
50+
super().__init__()
51+
self.channels = input_channel_num
52+
self.embeddings = nn.ModuleList(
53+
[nn.Embedding(9, 16) for _ in range(input_channel_num)])
54+
55+
self.conv1 = nn.Conv2d(64, 64, kernel_size=8, stride=4)
56+
self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2)
57+
self.bn = nn.BatchNorm2d(64)
58+
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=2)
59+
60+
self.encoder = nn.Linear(extra_size, 64)
61+
62+
self.fc1 = nn.Linear(64, 32)
63+
self.fc2 = nn.Linear(32, num_actions*3)
64+
65+
def forward(self, x, y):
66+
B, C, H, W = x.shape
67+
embedded_channels = []
68+
for i in range(self.channels):
69+
flattened_channel = x[:, i, :, :].view(B, -1).long()
70+
embedded_channel = self.embeddings[i](flattened_channel)
71+
embedded_channel = embedded_channel.view(
72+
B, 16, H, W)
73+
embedded_channels.append(embedded_channel)
74+
# Concatenate along the channel dimension
75+
x = torch.cat(embedded_channels, dim=1)
76+
x = F.relu(self.conv1(x))
77+
x = F.relu(self.conv2(x))
78+
x = self.bn(x)
79+
x = F.relu(self.conv3(x))
80+
# print(x.shape)
81+
82+
y = F.sigmoid(self.encoder(y))
83+
84+
x = x.view(x.size(0), -1)
85+
x = F.relu(self.fc1(x+y))
86+
return self.fc2(x).view(-1, 3, 5)
87+
88+
89+
# test the shape of the output
90+
if __name__ == "__main__":
91+
rand_input = torch.rand(1, 4, 38, 38)
92+
extra_input = torch.rand(1, 10)
93+
pacman_net = PacmanNet(4, 5, 10)
94+
res = pacman_net(rand_input, extra_input)
95+
print(res.shape)
96+
97+
ghost_net = GhostNet(4, 5, 10)
98+
res = ghost_net(rand_input, extra_input)
99+
print(res.shape)

0 commit comments

Comments
 (0)