-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdeepmind_mcts.py
290 lines (240 loc) · 7.77 KB
/
deepmind_mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# I absolutely hate this sys path stuff
import sys
sys.path.append(sys.path[0] + "/..")
import chess
import math
import numpy as np
import argparse
import datetime
import operator
import random
import copy
import pychess_utils as util
from random import choice
from chess import pgn, uci
from collections import defaultdict
from rpc_client import PredictClient
# Marker for drawn games — not referenced in this file; presumably consumed
# by callers that record game results (TODO confirm against callers).
DRAW = 'draw'
# This constant should change exploration tendencies
# (the c_puct coefficient in Edge.get_confidence: higher -> more exploration).
CPUCT = 1.5
# Not sure these really boost performance much, but I think they should
# help if games are played back-to-back
# Module-level caches keyed by FEN string: policy-head and value-head
# predictions respectively (shared across MCTS instances in this process).
prediction_cache = {}
value_cache = {}
# Model-server endpoint, resolved once at import time.
ADDRESS = util.get_address()
PORT = util.get_port()
class Edge:
    """A directed edge of the MCTS tree, holding the per-move search
    statistics: prior probability P, visit count N, total value W and
    mean value Q (AlphaZero-style)."""

    def __init__(self, node, move, prob, simulations=0, total_action_value=0, action_value=0):
        self.node = node                              # parent Node this edge leaves from
        self.prob = prob                              # P: prior from the policy network
        self.total_action_value = total_action_value  # W: sum of backed-up values
        self.action_value = action_value              # Q: mean value (W / N)
        self.simulations = simulations                # N: visit count
        self.move = move                              # chess.Move this edge represents

    @property
    def move(self):
        return self.__move

    @move.setter
    def move(self, move):
        # Anything that is not an actual chess.Move is silently dropped.
        self.__move = move if isinstance(move, chess.Move) else None

    @property
    def prob(self):
        return self.__prob

    @prob.setter
    def prob(self, prob):
        # Priors can never be negative; clamp to zero.
        self.__prob = prob if prob >= 0 else 0

    @property
    def simulations(self):
        return self.__simulations

    @simulations.setter
    def simulations(self, sims):
        # Visit counts can never be negative; clamp to zero.
        self.__simulations = sims if sims >= 0 else 0

    def get_siblings(self):
        """Return every other edge leaving this edge's parent node."""
        return [edge for _, edge in self.node.children if edge is not self]

    def total_sims_at_depth(self):
        """Total visit count over this edge and all of its siblings."""
        return self.simulations + sum(s.simulations for s in self.get_siblings())

    def get_confidence(self):
        """PUCT exploration bonus U = CPUCT * P * sqrt(sum_b N_b) / (1 + N)."""
        term1 = CPUCT * self.prob
        term2 = math.sqrt(self.total_sims_at_depth()) / (1 + self.simulations)
        return term1 * term2
class Node:
    """A position node of the MCTS tree.

    children is a list of (Node, Edge) tuples appended during expansion;
    parent is the Node this one was expanded from (None for the root);
    color is the side to move here — presumably chess.WHITE/BLACK booleans,
    alternated each ply by expand_tree (TODO confirm).
    """

    def __init__(self, color, parent=None, position=None):
        self.position = position
        self.color = color
        self.parent = parent
        self.children = []

    @property
    def position(self):
        return self.__position

    @position.setter
    def position(self, position):
        # Anything that is not an actual chess.Board is silently dropped.
        self.__position = position if isinstance(position, chess.Board) else None

    @property
    def children(self):
        return self.__children

    @children.setter
    def children(self, children):
        self.__children = children

    @property
    def parent(self):
        return self.__parent

    @parent.setter
    def parent(self, parent):
        self.__parent = parent
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search driven by a remote
    policy/value network queried through a gRPC PredictClient."""

    # 1600 Iterations is simply too intense for this machine...
    # Need to look into splitting into threads, queuing requests and cloud resources
    ITERATIONS_PER_BUILD = 100
    ITER_TIME = 5

    def __init__(self, startpos=None, iterations=None, iter_time=None,
                 prev_mcts=None, temperature=True, version=0, startcolor=True):
        """Build a search tree rooted at startpos.

        startpos    -- chess.Board to search from. Default is None, which the
                       startpos setter turns into a fresh Board — this replaces
                       the original mutable default `chess.Board()` that was
                       shared (and mutated) across instances.
        iterations  -- simulations per untimed build() (default ITERATIONS_PER_BUILD)
        iter_time   -- seconds per timed build() (default ITER_TIME)
        prev_mcts   -- previous MCTS whose statistics for startpos are reused
        temperature -- if True, best_move() samples proportionally to visits
        version     -- model version to query (falsy -> latest)
        startcolor  -- side to move at the root
        """
        self.version = version if version else util.latest_version()
        # gRPC client to query the trained model at localhost:9000
        # SERVER MUST BE RUNNING LOCALLY
        self.__client = PredictClient(ADDRESS, PORT, 'ACZ', int(self.version))
        self.startpos = startpos
        if prev_mcts:
            # This saves the statistics about this startpos from the prev_mcts
            self.__root = prev_mcts.child_matching(self.startpos)
            if not self.__root:
                print("Could not find move in previous tree.")
                self.__root = Node(startcolor, position=self.startpos)
        else:
            self.__root = Node(startcolor, position=self.startpos)
        self.iterations = iterations if iterations else self.ITERATIONS_PER_BUILD
        self.iter_time = iter_time if iter_time else self.ITER_TIME
        self.temperature = temperature

    def child_matching(self, position):
        """Return the root's child Node whose position equals `position`, or None."""
        if not self.__root.children:
            return None
        for child, edge in self.__root.children:
            if child.position == position:
                return child
        return None

    def max_action_val_child(self, root):
        """Return the (child, edge) pair maximizing Q + U (PUCT selection).

        Later children win ties (>=), matching the original behavior.
        """
        max_val = -float("inf")
        max_child = None
        max_edge = None
        for child, edge in root.children:
            # Compute the score once instead of twice per iteration.
            score = edge.action_value + edge.get_confidence()
            if score >= max_val:
                max_child, max_edge, max_val = child, edge, score
        return (max_child, max_edge)

    def most_visited_child(self, root):
        """Return an edge with the highest visit count, ties broken randomly.

        BUG FIX: the original never updated max_visits, so every edge passed
        the >= test and the result was a uniform choice over ALL children
        rather than over the most-visited ones.
        """
        max_visits = 0
        choices = []
        for child, edge in root.children:
            if edge.simulations > max_visits:
                max_visits = edge.simulations
                choices = [edge]
            elif edge.simulations == max_visits:
                choices.append(edge)
        return random.choice(choices)

    def total_child_visits(self, root):
        """Sum of simulation counts over all edges leaving `root`."""
        return sum(edge.simulations for child, edge in root.children)

    @property
    def iterations(self):
        return self.__iterations

    @iterations.setter
    def iterations(self, iters):
        # Clamp to [0, 3200]; 3200 is an arbitrary safety cap.
        self.__iterations = max(0, min(iters, 3200))

    @property
    def startpos(self):
        return self.__startpos

    @startpos.setter
    def startpos(self, startpos):
        # Anything that is not a Board (including None) becomes a fresh game.
        self.__startpos = startpos if isinstance(startpos, chess.Board) else chess.Board()

    @property
    def iter_time(self):
        return self.__iter_time

    @iter_time.setter
    def iter_time(self, time):
        # Non-positive durations fall back to 100 seconds.
        self.__iter_time = time if time > 0 else 100

    def search(self):
        """One simulation: select a leaf, expand it, evaluate it with the
        value network (cached by FEN), and back the value up the path."""
        leaf = self.select_leaf(self.__root)
        self.expand_tree(leaf)
        fen = leaf.position.fen()
        # Membership test rather than .get() so a legitimately-zero (or
        # otherwise falsy) cached value is not re-predicted every visit.
        if fen not in value_cache:
            try:
                value_cache[fen] = self.__client.predict(util.expand_position(leaf.position))[0]
            except Exception:
                # Single retry — the model server occasionally drops a request.
                print("Prediction error, retrying...")
                value_cache[fen] = self.__client.predict(util.expand_position(leaf.position))[0]
        self.backprop(leaf, value_cache[fen])

    def build(self, timed=False):
        """Run simulations: iter_time seconds if timed, else a fixed count."""
        if timed:
            begin = datetime.datetime.utcnow()
            limit = datetime.timedelta(seconds=self.iter_time)
            while datetime.datetime.utcnow() - begin < limit:
                self.search()
        else:
            for _ in range(self.iterations):
                self.search()

    def select_leaf(self, root):
        """Walk down by PUCT score until reaching a node with no children."""
        while root:
            if not root.children:
                return root
            root = self.max_action_val_child(root)[0]
        # Only reachable if a children list held a None child node.
        print("Shouldn't hit this point.")
        return Node(True)

    def expand_tree(self, leaf):
        """Create a (Node, Edge) child for every legal move from `leaf`,
        with priors from the (FEN-cached) policy network.

        Returns the list of new child Nodes, or None when `leaf` carries
        no position.
        """
        if not leaf.position:
            print("MCTS tried to expand with empty position.")
            return None
        board = leaf.position
        fen = board.fen()
        # Membership test rather than .get() — see search() for rationale.
        if fen not in prediction_cache:
            try:
                prediction_cache[fen] = self.__client.predict(util.expand_position(board), 'policy')
            except Exception:
                print("Prediction error, retrying...")
                prediction_cache[fen] = self.__client.predict(util.expand_position(board), 'policy')
        moves = list(board.legal_moves)
        if len(moves) == 0:
            print("No moves from position: " + board.fen() + "\n")
        new_leaves = []
        for selected_move in moves:
            new_board = copy.deepcopy(board)
            pred_index = util.get_prediction_index(selected_move)
            new_edge = Edge(leaf, selected_move,
                            util.logit_to_prob(prediction_cache[fen][pred_index]))
            new_board.push(selected_move)
            # Children alternate side-to-move each ply.
            new_node = Node(not leaf.color, parent=leaf, position=new_board)
            leaf.children.append((new_node, new_edge))
            new_leaves.append(new_node)
        return new_leaves

    def backprop(self, leaf, value):
        """Propagate `value` from `leaf` to the root, updating N, W and Q.

        NOTE(review): the path is re-derived via max_action_val_child at
        each level; this assumes the PUCT maximizer still selects the same
        edges select_leaf just walked — confirm before refactoring.
        """
        node = leaf.parent
        while node:
            path_edge = self.max_action_val_child(node)[1]
            path_edge.simulations += 1
            path_edge.total_action_value += value
            path_edge.action_value = path_edge.total_action_value / path_edge.simulations
            node = node.parent

    def get_policy_string(self):
        """Serialize the root visit distribution as '(index:prob)' pairs
        joined by '#', for training-data export."""
        total_visits = self.total_child_visits(self.__root)
        policy = []
        for child, edge in self.__root.children:
            prob = edge.simulations / total_visits
            policy.append("(" + str(util.get_prediction_index(edge.move)) + ":" + str(prob) + ")")
        return '#'.join(policy)

    def best_move(self):
        """Pick a move at the root: visit-weighted sample when temperature
        is on, otherwise (one of) the most visited."""
        if not self.temperature:
            return self.most_visited_child(self.__root).move
        choices = []
        chances = defaultdict(int)
        for child, edge in self.__root.children:
            # One list entry per visit -> weighted random selection.
            choices += [edge.move] * edge.simulations
            chances[edge.move.uci()] = edge.simulations
        # This does a weighted random selection based on simulations
        choice = random.choice(choices)
        print("{0} was chosen with chance: {1:.4f} out of {2} options".format(
            choice.uci(),
            float(chances[choice.uci()]) / float(len(choices)),
            len(self.__root.children)))
        return choice