forked from tseip/fourinarow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathninarow_vectorized_feature_evaluator.h
219 lines (200 loc) · 7.33 KB
/
ninarow_vectorized_feature_evaluator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#ifndef NINAROW_VECTORIZED_FEATURE_EVALUATOR_H_INCLUDED
#define NINAROW_VECTORIZED_FEATURE_EVALUATOR_H_INCLUDED
#include <Eigen/Dense>
#include <bitset>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ninarow_heuristic_feature.h"
#include "player.h"
namespace NInARow {
/**
* Counts the number of overlapping bits between a given bitset and a vector of
* known bitsets in an efficient, vectorized way. Uses size_t instead of actual
* bits to keep track of total overlap counts in the final evaluation.
*
* @tparam N The maximum length of all of the bitsets in the known vector of
* bitsets.
*/
template <std::size_t N>
class VectorizedBitsetCounter {
private:
/**
* Represents the known matrix of bitsets - an matrix of size N * M,
* where M is the number of bitsets that have been registered for
* evaluation. Each bitset is stored as a column in the matrix.
*/
Eigen::Matrix<std::size_t, N, Eigen::Dynamic> bitset_matrix;
/**
* Converts a bitset to a one-dimensional vector of size_ts
*
* @param bitset The set of bits to convert.
*
* @return A vector of size_t, where each set element of the bitset
* corresponds to a 1 in the vector.
*/
static Eigen::Vector<std::size_t, N> bitset_to_vector(
const std::bitset<N> &bitset) {
Eigen::Vector<std::size_t, N> vector;
for (std::size_t i = 0; i < N; ++i) {
vector(i) = static_cast<std::size_t>(bitset[i]);
}
return vector;
}
/**
* Converts a list of bitsets to a matrix.
*
* @param bitsets The list of bitsets to convert.
*
* @return A matrix of size_t, where each row of the matrix corresponds to a
* bitset from the input.
*/
static Eigen::Matrix<std::size_t, Eigen::Dynamic, N> bitsets_to_matrix(
const std::vector<std::bitset<N>> &bitsets) {
Eigen::Matrix<std::size_t, Eigen::Dynamic, N> matrix(0, N);
matrix.conservativeResize(bitsets.size(), Eigen::NoChange);
for (std::size_t i = 0; i < bitsets.size(); ++i) {
matrix.row(i) = bitset_to_vector(bitsets[i]);
}
return matrix;
}
public:
/**
* Constructor.
*/
VectorizedBitsetCounter() : bitset_matrix(N, 0) {}
/**
* Adds a bitset into our known pool. After this function is called, each
* query will return an additional line representing the bit overlap count
* with this bitset.
*
* @param bitset The bitset to add.
*/
void register_bitset(const std::bitset<N> &bitset) {
bitset_matrix.conservativeResize(Eigen::NoChange, bitset_matrix.cols() + 1);
bitset_matrix.col(bitset_matrix.cols() - 1) = bitset_to_vector(bitset);
}
/**
* Queries all of the added bitsets against a list of new bitsets. Returns a
* vector of vectors, where each top-level vector corresponds to a single
* bitset passed in, and each element of the subvectors corresponds to a count
* of the overlapping bits between each line of our registered bitsets and the
* given bitset.
*
* @param bitsets The bitsets to query against.
*
* @return A list of lists of bit overlap counts, where each element
* corresponds to the bit overlap count for each registered bitset against
* each given bitset.
*/
std::vector<std::vector<std::size_t>> query(
std::vector<std::bitset<N>> bitsets) const {
const auto m = bitsets_to_matrix(bitsets);
const Eigen::Matrix<std::size_t, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
count_results = m * bitset_matrix;
std::vector<std::vector<std::size_t>> output;
for (std::size_t i = 0; i < bitsets.size(); ++i) {
const auto row = count_results.row(i);
output.emplace_back(row.data(), row.data() + row.size());
}
return output;
}
};
/**
* Registers a number of features that can all be evaluated simultaneously and
* efficiently on given boards.
*
* @tparam Board The board that the feature will evaluate.
*/
template <typename Board>
class VectorizedFeatureEvaluator {
 private:
  /**
   * The number of features we're tracking.
   */
  std::size_t feature_count;

  /**
   * A counter representing the set of all of the pieces corresponding to all
   * of the features we're tracking. (A feature comprises pieces and spaces.)
   * Each registered bitset of this counter represents one feature's pieces.
   */
  VectorizedBitsetCounter<Board::get_board_size()> feature_pieces_bitsets;

  /**
   * A counter representing the set of all of the spaces corresponding to all
   * of the features we're tracking. (A feature comprises pieces and spaces.)
   * Each registered bitset of this counter represents one feature's spaces.
   */
  VectorizedBitsetCounter<Board::get_board_size()> feature_spaces_bitsets;

 public:
  /**
   * Constructor. Starts with no registered features.
   */
  VectorizedFeatureEvaluator()
      : feature_count(0), feature_pieces_bitsets(), feature_spaces_bitsets() {}

  /**
   * Adds a new feature to the evaluator.
   *
   * @param feature The feature to add.
   *
   * @return The zero-based index assigned to the newly added feature, i.e.
   * the number of features that were being tracked before this call. This is
   * the position at which the feature's count appears in each result of
   * query_pieces/query_spaces.
   */
  std::size_t register_feature(const HeuristicFeature<Board> &feature) {
    feature_pieces_bitsets.register_bitset(feature.pieces.positions);
    feature_spaces_bitsets.register_bitset(feature.spaces.positions);
    // Post-increment on purpose: hand back the index of the feature that was
    // just registered, not the new total.
    return feature_count++;
  }

  /**
   * Given a list of boards and a player, count the number of pieces that the
   * player has on each board which overlap with each of our registered
   * features' pieces.
   *
   * @param boards The boards to evaluate.
   * @param player The player whose pieces we are evaluating.
   *
   * @return One list per board (in input order), each holding one overlap
   * count per registered feature (in registration order).
   */
  std::vector<std::vector<std::size_t>> query_pieces(
      const std::vector<Board> &boards, Player player) const {
    std::vector<std::bitset<Board::get_board_size()>> positions;
    positions.reserve(boards.size());
    for (const auto &board : boards) {
      positions.push_back(board.get_pieces(player).positions);
    }
    return feature_pieces_bitsets.query(positions);
  }

  /**
   * Given a list of boards, count the number of spaces on each board which
   * overlap with each of our registered features' spaces.
   *
   * @param boards The boards to evaluate.
   *
   * @return One list per board (in input order), each holding the amount of
   * overlap between the board's spaces and each registered feature's spaces
   * (in registration order).
   */
  std::vector<std::vector<std::size_t>> query_spaces(
      const std::vector<Board> &boards) const {
    std::vector<std::bitset<Board::get_board_size()>> spaces;
    spaces.reserve(boards.size());
    for (const auto &board : boards) {
      spaces.push_back(board.get_spaces().positions);
    }
    return feature_spaces_bitsets.query(spaces);
  }

  /**
   * Helper functions for calling query_pieces/spaces on single board inputs
   * easily. Each returns the single result list that the corresponding
   * multi-board overload would produce for a one-element board list.
   *
   * @{
   */
  std::vector<std::size_t> query_pieces(const Board &board,
                                        Player player) const {
    // Build the one-element position list directly so the board itself is
    // never copied.
    std::vector<std::bitset<Board::get_board_size()>> positions;
    positions.push_back(board.get_pieces(player).positions);
    auto counts = feature_pieces_bitsets.query(positions);
    // The counter returns exactly one entry per queried bitset, so front()
    // is always valid here.
    return std::move(counts.front());
  }

  std::vector<std::size_t> query_spaces(const Board &board) const {
    std::vector<std::bitset<Board::get_board_size()>> spaces;
    spaces.push_back(board.get_spaces().positions);
    auto counts = feature_spaces_bitsets.query(spaces);
    return std::move(counts.front());
  }
  /**
   * @}
   */
};
} // namespace NInARow
#endif // NINAROW_VECTORIZED_FEATURE_EVALUATOR_H_INCLUDED