forked from rspeer/dominionstats
-
Notifications
You must be signed in to change notification settings - Fork 17
/
game_state_features.py
201 lines (160 loc) · 6.85 KB
/
game_state_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/python
# -*- coding: utf-8 -*-
""" Convert game documents to a format easily readable by R."""
import itertools
import card_info as ci
import game
import random
import utils
def nice_feature_name(n):
    """Turn a card name into an identifier-safe column name.

    Apostrophes are dropped and spaces become underscores, e.g.
    ``"Philosopher's Stone"`` -> ``"Philosophers_Stone"``.
    """
    without_apostrophes = n.replace("'", '')
    return '_'.join(without_apostrophes.split(' '))
def composition_deck_extractor(deck_comp, game_state, player):
    """Return how many copies of each known card ``deck_comp`` holds.

    One entry per name in ``ci.card_names()``, in that order; cards missing
    from ``deck_comp`` count as 0. ``game_state`` and ``player`` are unused
    here but keep the signature shared by every ``*_deck_extractor``.
    """
    return [deck_comp.get(card_name, 0) for card_name in ci.card_names()]


# Column names for the values above: one identifier-safe name per card.
composition_deck_extractor.feature_names = list(
    map(nice_feature_name, ci.card_names()))
def score_deck_extractor(deck_comp, game_state, player):
    """Return ``[score]`` for ``player`` as reported by ``game_state.player_score``."""
    current_score = game_state.player_score(player)
    return [current_score]
def deck_size_deck_extractor(deck_comp, game_state, player):
    """Return ``[total card count]``: the sum of all counts in ``deck_comp``."""
    total_cards = 0
    for count in deck_comp.values():
        total_cards += count
    return [total_cards]
def action_balance_deck_extractor(deck_comp, game_state, player):
    """Return ``[net +Actions per card]`` for the composition ``deck_comp``.

    Each card contributes ``ci.num_plus_actions(card)`` minus one if
    ``ci.is_action(card)``, weighted by its count. The total is divided by
    the number of cards; an empty deck divides by 1.
    """
    net_actions = 0
    for card, quant in deck_comp.items():
        # Subtract one per action card (presumably the action spent playing it).
        net_actions += (ci.num_plus_actions(card) - ci.is_action(card)) * quant
    deck_size = sum(deck_comp.values()) or 1  # guard against an empty deck
    # float() forces true division: on Python 2 ints, `/` floors, which
    # collapsed this ratio to (mostly) 0 or -1.
    return [float(net_actions) / deck_size]
def unique_deck_extractor(deck_comp, game_state, player):
    """Return ``[number of keys in deck_comp]``.

    Every key counts, including any whose count is 0.
    """
    distinct_cards = len(deck_comp)
    return [distinct_cards]
def outcome_special_extractor(g, game_state):
    """Return ``[WinPoints()]`` of the first player in the state's turn order."""
    first_player = game_state.player_turn_order()[0]
    first_deck = g.get_player_deck(first_player)
    return [first_deck.WinPoints()]
def turn_tiebreaker_common_extractor(g, game_state):
    """Return ``[TurnOrder()]`` of the first player in the state's turn order."""
    leading_player = game_state.player_turn_order()[0]
    return [g.get_player_deck(leading_player).TurnOrder()]
def num_piles_empty_common_extractor(g, game_state):
    """Return ``[number of entries in game_state.supply whose count is exactly 0]``."""
    empty_piles = sum(1 for remaining in game_state.supply.values()
                      if remaining == 0)
    return [empty_piles]
def num_piles_low_common_extractor(g, game_state):
    """Return ``[number of supply entries with a count of 2 or fewer]``.

    Empty (0) piles are included in this count.
    """
    low_piles = sum(1 for remaining in game_state.supply.values()
                    if remaining <= 2)
    return [low_piles]
def prov_or_colony_low_extractor(g, game_state):
    """Return ``[1]`` if the Province or Colony pile has 1 or 2 cards left, else ``[0]``.

    An empty (0) or absent pile does not count.

    NOTE(review): the name does not end in ``common_extractor``, so
    make_extractor_list never collects it. It is therefore not part of the
    exported feature vector; confirm whether that is intended.
    """
    supply = game_state.supply
    for pile_name in ('Province', 'Colony'):
        if 1 <= supply.get(pile_name, 0) <= 2:
            return [1]
    return [0]
def turn_no_common_extractor(g, game_state):
    """Return ``[game_state.turn_index()]``."""
    index = game_state.turn_index()
    return [index]
def supply_common_extractor(g, game_state):
    """Return the remaining supply count of every known card.

    One entry per name in ``ci.card_names()``, in that order; cards missing
    from ``game_state.supply`` count as 0.
    """
    supply = game_state.supply
    return [supply.get(card_name, 0) for card_name in ci.card_names()]


# Column names for the values above: one identifier-safe name per card.
supply_common_extractor.feature_names = list(
    map(nice_feature_name, ci.card_names()))
def make_extractor_list(suffix):
    """Collect every module-level object whose name ends with ``suffix``.

    An extractor without a ``feature_names`` attribute gets one set to
    ``[name minus suffix]``. For example, ``score_deck_extractor`` with
    suffix ``'deck_extractor'`` gets ``['score_']``.

    :param suffix: naming suffix that marks a function as an extractor.
    :returns: the matching objects, in ``globals()`` iteration order.
    """
    module_namespace = globals()
    extractors = []
    # Snapshot the names so the scan is independent of any later binding.
    for name in list(module_namespace):
        if not name.endswith(suffix):
            continue
        # Direct namespace lookup instead of eval() on a rebuilt name.
        extractor = module_namespace[name]
        if not hasattr(extractor, 'feature_names'):
            extractor.feature_names = [name[:-len(suffix)]]
        extractors.append(extractor)
    return extractors
# Both lists are built once, at import time, by scanning globals(). These
# lines must therefore stay below every *_deck_extractor / *_common_extractor
# definition.
# Per-player extractors, called as f(deck_comp, game_state, player_name).
_deck_extractor_list = make_extractor_list('deck_extractor')
# Whole-game extractors, called as f(g, game_state).
_common_extractor_list = make_extractor_list('common_extractor')
def feature_names(feature_extractor_list):
    """Concatenate the ``feature_names`` of each extractor, in list order.

    Returns a new list; the extractors' own lists are not modified.
    """
    return [name
            for extractor in feature_extractor_list
            for name in extractor.feature_names]
def state_to_features(g, game_state):
    """Encode one game state as a flat list of feature values.

    Layout, in order:
      1. every common extractor's output;
      2. the deck-extractor outputs of each player in
         ``game_state.player_turn_order()``;
      3. the element-wise difference (first player minus second player);
      4. the outcome from ``outcome_special_extractor``.

    Exactly two players are assumed. Unpacking the per-player vectors raises
    ValueError otherwise.
    """
    features = []
    for extractor in _common_extractor_list:
        features.extend(extractor(g, game_state))

    player_vectors = []
    for name in game_state.player_turn_order():
        composition = game_state.get_deck_composition(name)
        vector = []
        for extractor in _deck_extractor_list:
            vector.extend(extractor(composition, game_state, name))
        player_vectors.append(vector)
        features.extend(vector)

    mine, theirs = player_vectors
    features.extend(a - b for a, b in zip(mine, theirs))
    features.extend(outcome_special_extractor(g, game_state))
    return features
def output_state(state, output_file, sep=' '):
    """Write ``state`` as one line: each value converted with ``unicode``, joined by ``sep``."""
    fields = [unicode(value) for value in state]
    output_file.write(sep.join(fields))
    output_file.write('\n')
def get_all_feature_names():
    """Return the column names, in the order state_to_features emits values.

    The names are the common feature names, then the deck feature names
    three times (prefixed ``my_``, ``opp_`` and ``diff_``), then ``outcome_``.
    """
    deck_names = feature_names(_deck_extractor_list)
    header = feature_names(_common_extractor_list)
    header.extend('%s_%s' % (label, name)
                  for label in ('my', 'opp', 'diff')
                  for name in deck_names)
    header.append('outcome_')
    return header
def write_r_header(output_file):
    """Write the space-separated column-name line heading the R data file."""
    column_names = get_all_feature_names()
    output_file.write('%s\n' % ' '.join(column_names))
def write_weka_header(output_file, force_classification):
    """Write the Weka ARFF header: relation name and attribute declarations.

    Every feature except the last is declared NUMERIC. The trailing
    ``outcome_`` attribute is declared nominal ``{0,1}`` when
    ``force_classification`` is true; main() halves the label and truncates
    it to an int in that mode. Otherwise ``outcome_`` is declared NUMERIC, so
    the raw win-points label can be used for regression. Previously this
    case was rejected with an ``assert``, which disappears under ``python -O``.
    """
    output_file.write('@RELATION isotropic_games\n\n')
    for feature_name in get_all_feature_names()[:-1]:
        output_file.write('@ATTRIBUTE %s NUMERIC\n' % feature_name)
    if force_classification:
        output_file.write('@ATTRIBUTE outcome_ {0,1}')
    else:
        output_file.write('@ATTRIBUTE outcome_ NUMERIC')
    output_file.write('\n@DATA\n')
def output_libsvm_state(state, output_file):
    """Write ``state`` as one libsvm line: ``<label> <index>:<value> ... \\n``.

    The last element of ``state`` is the label: 0 is written as ``-1``,
    anything else as ``1``. The remaining elements are features, numbered
    from 1; zero-valued features are omitted (sparse format).

    Values use ``%.12g``. Integers print exactly as with ``%d``, but
    fractional values are no longer silently truncated toward zero.
    """
    label = '-1 ' if state[-1] == 0 else '1 '
    pairs = ''.join('%d:%.12g ' % (index + 1, value)
                    for index, value in enumerate(state[:-1])
                    if value != 0)
    # Keeps the original byte layout, including the space before the newline.
    output_file.write(label + pairs + '\n')
def main(prefix='data/test_small_', limit=20000, force_classification=True):
    """Export per-state feature vectors for two-player games from MongoDB.

    Reads up to ``limit`` documents from ``test.games`` whose ``_id`` is
    greater than ``'game-2010-10'``. Dubious-quality and non-two-player
    games are skipped, as are ties when ``force_classification`` is set.
    Every state of each remaining game is written as one libsvm line. The
    R, Weka and librf files are created with headers only (R and Weka) or
    empty (librf); their per-state writes are commented out below.

    :param prefix: path prefix for all output files.
    :param limit: maximum number of games to fetch.
    :param force_classification: turn the outcome into a 0/1 label and drop ties.
    """
    c = utils.get_mongo_connection()
    output_files = []
    try:
        for file_suffix in ('r_format.data', 'games.arff', 'librf_games.csv',
                            'librf_games_labels.txt', 'libsvm_games.txt'):
            output_files.append(open(prefix + file_suffix, 'w'))
        (r_output_file, weka_output_file, librf_output_file,
         librf_labels_file, libsvm_output_file) = output_files

        write_r_header(r_output_file)
        write_weka_header(weka_output_file, force_classification)

        games = c.test.games.find(
            {'_id': {'$gt': 'game-2010-10'}}).limit(limit)
        for raw_game in utils.progress_meter(games, 100):
            g = game.Game(raw_game)
            if g.dubious_quality() or len(g.get_player_decks()) != 2:
                continue
            if (force_classification
                    and g.get_player_decks()[0].WinPoints() == 1.0):
                # A tie has no 0/1 label; halving below would silently turn
                # it into a loss, so actually skip it (the original only
                # printed this message).
                print('skipping tie')
                continue
            # Every state is exported (no per-game sampling). The original
            # drew an unused random turn index, which raised ValueError on
            # games with no turns.
            for game_state in g.game_state_iterator():
                encoded_state = state_to_features(g, game_state)
                if force_classification:
                    # Presumably win points 2 -> 1 (win) and 0 -> 0 (loss).
                    encoded_state[-1] = int(encoded_state[-1] / 2)
                # Alternate output formats, currently disabled:
                # output_state(encoded_state, r_output_file, ' ')
                # output_state(encoded_state, weka_output_file, ',')
                # output_state(encoded_state[:-1], librf_output_file, ',')
                # librf_labels_file.write('%d\n' % encoded_state[-1])
                output_libsvm_state(encoded_state, libsvm_output_file)
    finally:
        # Flush and close every file that was successfully opened.
        for output_file in output_files:
            output_file.close()
# Script entry point: run the feature export when executed directly.
if __name__ == '__main__':
    main()