-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert.py
116 lines (96 loc) · 3.88 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import h5py
import numpy as np
from config import CHANNELS
from policy import PolicyValueModelResNet as PolicyValueModel
def convert_pretrained_weights(
    src_weights_file,
    dst_weights_file,
    src_width=8,
    dst_width=15,
    src_height=8,
    dst_height=15,
):
    """Transfer the ``cnn_layers`` weights of a trained model to a new board size.

    A source model is built for the source board and loaded from
    ``src_weights_file``; a destination model is built for the destination
    board. The weights of each entry in ``cnn_layers`` are copied across
    pairwise, and the destination model is saved to ``dst_weights_file``.
    Nothing outside ``cnn_layers`` is copied, so the rest of the destination
    model keeps whatever ``build()`` produced.

    Args:
        src_weights_file: Path of the weights file to load.
        dst_weights_file: Path the converted weights are saved to.
        src_width: Board width the source model was built for.
        dst_width: Board width of the destination model.
        src_height: Board height the source model was built for.
        dst_height: Board height of the destination model.

    Raises:
        AssertionError: If the two models expose a different number of
            ``cnn_layers``.
    """

    def _build_model(width, height):
        # Build eagerly so the layers own their variables before any
        # weights are loaded or assigned.
        model = PolicyValueModel(width, height)
        model.build(input_shape=(None, width, height, CHANNELS))
        return model

    source_model = _build_model(src_width, src_height)
    source_model.load_weights(src_weights_file)
    target_model = _build_model(dst_width, dst_height)

    source_layers = source_model.cnn_layers
    target_layers = target_model.cnn_layers
    assert len(source_layers) == len(target_layers)

    for source_layer, target_layer in zip(source_layers, target_layers):
        target_layer.set_weights(source_layer.get_weights())

    target_model.save_weights(dst_weights_file)
def convert_pretrained_buffer(
    src_buffer_file,
    dst_buffer_file,
    src_width=8,
    dst_width=15,
    src_height=8,
    dst_height=15,
):
    """Embed a small-board replay buffer in the centre of a larger board.

    Reads the ``states``, ``mcts_probs`` and ``rewards`` datasets from the
    source HDF5 file, zero-pads the first two around the source board so it
    sits in the middle of the destination board (offsets are floored when the
    size difference is odd), and writes all three datasets to the destination
    HDF5 file. ``rewards`` is copied unchanged.

    ``states`` is indexed as ``(sample, width, height, channel)``. The flat
    ``mcts_probs`` vectors are reshaped to ``(width, height)`` to match that
    axis order.
    NOTE(review): for non-square boards this assumes the flat move index is
    laid out width-major like ``states`` - confirm against the code that
    produces the buffer. Square boards are unaffected.

    Args:
        src_buffer_file: Path of the HDF5 buffer to read.
        dst_buffer_file: Path of the HDF5 buffer to create (overwritten).
        src_width: Width of the source board.
        dst_width: Width of the destination board.
        src_height: Height of the source board.
        dst_height: Height of the destination board.

    Raises:
        AssertionError: If the destination board is smaller than the source
            board along either dimension.
    """
    assert dst_height >= src_height
    assert dst_width >= src_width

    # Load everything before touching the destination: the source handle is
    # released even if a dataset is missing, and the destination file is not
    # truncated when reading fails.
    with h5py.File(src_buffer_file, "r") as f_src:
        states_src = f_src["states"][...]
        mcts_probs_src = f_src["mcts_probs"][...]
        rewards = f_src["rewards"][...]

    buffer_length = states_src.shape[0]

    # Window of the destination board that receives the source board.
    start_width_idx = (dst_width - src_width) // 2
    start_height_idx = (dst_height - src_height) // 2
    width_window = slice(start_width_idx, start_width_idx + src_width)
    height_window = slice(start_height_idx, start_height_idx + src_height)

    states_dst = np.zeros(
        shape=(buffer_length, dst_width, dst_height, CHANNELS),
        dtype=states_src.dtype,
    )
    states_dst[:, width_window, height_window] = states_src
    # The last channel is a constant plane (all ones or all zeros), so the
    # padded cells must carry the same value: broadcast the value found at
    # cell (0, 0) of each sample over the whole destination board.
    states_dst[:, :, :, -1] = states_src[:, 0:1, 0:1, -1]

    # Size the policy grid from both dimensions; using the width twice is
    # only correct for square boards.
    mcts_probs_dst = np.zeros(
        shape=(buffer_length, dst_width, dst_height),
        dtype=mcts_probs_src.dtype,
    )
    mcts_probs_dst[:, width_window, height_window] = mcts_probs_src.reshape(
        (buffer_length, src_width, src_height)
    )
    mcts_probs_dst = mcts_probs_dst.reshape((buffer_length, dst_width * dst_height))

    with h5py.File(dst_buffer_file, "w") as f_dst:
        f_dst["states"] = states_dst
        f_dst["mcts_probs"] = mcts_probs_dst
        f_dst["rewards"] = rewards
if __name__ == "__main__":
    # Command-line entry point: parse board sizes and file locations, then
    # convert both the pretrained weights and the replay buffer.
    parser = argparse.ArgumentParser(description="Gomoku AlphaZero Weights Converter")

    # (flag, default, help) for the integer board-size options.
    size_options = (
        ("--src-width", 8, "源棋盘水平宽度"),
        ("--src-height", 8, "源棋盘竖直宽度"),
        ("--dst-width", 15, "目标棋盘水平宽度"),
        ("--dst-height", 15, "目标棋盘竖直宽度"),
    )
    for flag, default, help_text in size_options:
        parser.add_argument(flag, default=default, type=int, help=help_text)

    # (flag, default, help) for the file-location options.
    path_options = (
        ("--src-weights", "./data/model-8x8#5.h5", "源预训练权重存储位置"),
        ("--dst-weights", "./data/model-15x15#5.h5", "目标预训练权重存储位置"),
        ("--src-buffer", "./data/buffer-8x8#5.h5", "源经验池存储位置"),
        ("--dst-buffer", "./data/buffer-15x15#5.h5", "目标经验池存储位置"),
    )
    for flag, default, help_text in path_options:
        parser.add_argument(flag, default=default, help=help_text)

    args = parser.parse_args()

    # Both conversions share the same board geometry.
    board_sizes = {
        "src_width": args.src_width,
        "dst_width": args.dst_width,
        "src_height": args.src_height,
        "dst_height": args.dst_height,
    }

    # Migrate the small-board pretraining data to the large board.
    convert_pretrained_weights(args.src_weights, args.dst_weights, **board_sizes)
    convert_pretrained_buffer(args.src_buffer, args.dst_buffer, **board_sizes)