Skip to content

Commit 7bd36b9

Browse files
committed
fix: model and train
1 parent a4939f2 commit 7bd36b9

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

core

Submodule core updated 1 file

model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, input_channel_num, num_actions, extra_size):
1010
super().__init__()
1111
self.channels = input_channel_num
1212
self.embeddings = nn.ModuleList(
13-
[nn.Embedding(9, 16) for _ in range(input_channel_num)])
13+
[nn.Embedding(10, 16) for _ in range(input_channel_num)])
1414
self.conv1 = nn.Conv2d(64, 64, kernel_size=8, stride=4)
1515
self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2)
1616
self.bn = nn.BatchNorm2d(64)
@@ -50,7 +50,7 @@ def __init__(self, input_channel_num, num_actions, extra_size):
5050
super().__init__()
5151
self.channels = input_channel_num
5252
self.embeddings = nn.ModuleList(
53-
[nn.Embedding(9, 16) for _ in range(input_channel_num)])
53+
[nn.Embedding(10, 16) for _ in range(input_channel_num)])
5454

5555
self.conv1 = nn.Conv2d(64, 64, kernel_size=8, stride=4)
5656
self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2)
@@ -88,7 +88,7 @@ def forward(self, x, y):
8888

8989
# test the shape of the output
9090
if __name__ == "__main__":
91-
rand_input = torch.rand(1, 4, 38, 38)
91+
rand_input = torch.rand(1, 4, 41, 41)
9292
extra_input = torch.rand(1, 10)
9393
pacman_net = PacmanNet(4, 5, 10)
9494
res = pacman_net(rand_input, extra_input)

train.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# initialize networks
2626
policy_net_pacman = PacmanNet(4, 5, 40)
2727
target_net_pacman = PacmanNet(4, 5, 40)
28-
policy_net_pacman.load_state_dict(torch.load("pacman.pth"))
28+
# policy_net_pacman.load_state_dict(torch.load("pacman.pth"))
2929
target_net_pacman.load_state_dict(policy_net_pacman.state_dict())
3030
target_net_pacman.eval()
3131

@@ -34,7 +34,7 @@
3434

3535
policy_net_ghost = GhostNet(4, 5, 40)
3636
target_net_ghost = GhostNet(4, 5, 40)
37-
policy_net_ghost.load_state_dict(torch.load("ghost.pth"))
37+
# policy_net_ghost.load_state_dict(torch.load("ghost.pth"))
3838
target_net_ghost.load_state_dict(policy_net_ghost.state_dict())
3939
target_net_ghost.eval()
4040

@@ -87,24 +87,24 @@ def state_dict_to_tensor(state_dict):
8787
board = np.array(board)
8888
size = board.shape[0]
8989
# print(board)
90-
# pad board to 38x38
91-
padding_num = 38 - size
90+
# pad board to 41x41
91+
padding_num = 41 - size
9292
board = np.pad(board, pad_width=(0, padding_num),
9393
mode="constant", constant_values=0)
9494
# pacman position matrix
95-
pacman_pos = np.zeros((38, 38))
95+
pacman_pos = np.zeros((41, 41))
9696
if "pacman_pos" in state_dict:
9797
pacman_pos[state_dict["pacman_pos"][0] + padding_num][
9898
state_dict["pacman_pos"][1] + padding_num
9999
] = 1
100100

101101
# ghost position matrix
102-
ghost_pos = np.zeros((38, 38))
102+
ghost_pos = np.zeros((41, 41))
103103
if "ghost_pos" in state_dict:
104104
for ghost in state_dict["ghost_pos"]:
105105
ghost_pos[ghost[0] + padding_num][ghost[1] + padding_num] = 1
106106

107-
portal_pos = np.zeros((38, 38))
107+
portal_pos = np.zeros((41, 41))
108108
if "portal" in state_dict:
109109
portal = state_dict["portal"]
110110
if portal[0] != -1 and portal[1] != -1:

0 commit comments

Comments
 (0)