-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert.py
73 lines (62 loc) · 2.06 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from argparse import ArgumentParser
import numpy as np
import torch
import torch.onnx
from model import KataGoInferenceModel
from netparser import read_model
def main(args):
    """Build a KataGo inference model from a .bin network and export it to ONNX.

    Args:
        args: parsed command-line namespace. ``args.model`` and
            ``args.model_config`` are handed to ``read_model``;
            ``args.output`` is the destination path of the exported graph.

    Side effects:
        Prints progress messages and writes the ONNX file to ``args.output``
        via ``torch.onnx.export``.
    """
    spec = read_model(args.model, args.model_config)
    net = KataGoInferenceModel(spec)
    print("Model building completed")
    net.fill_weights()

    # Example inputs used only to trace the graph: batch of 10 on a 19x19
    # board. The 22 spatial / 19 global feature counts are hard-coded and
    # presumably must match what the network expects — TODO confirm.
    binary_example = torch.randn(10, 22, 19, 19)
    # Channel 0 is forced to ones; presumably the on-board mask — TODO confirm.
    binary_example[:, 0] = 1.0
    global_example = torch.randn(10, 19)

    input_names = ["input_binary", "input_global"]
    output_names = [
        "output_policy",
        "output_value",
        "output_miscvalue",
        "output_ownership",
    ]

    # Axes listed here are exported as symbolic dimensions, so the traced
    # example sizes above are not baked into the ONNX graph for them.
    batch_only = {0: "batch_size"}
    batch_and_board = {0: "batch_size", 2: "y_size", 3: "x_size"}
    dynamic_axes = {
        "input_binary": dict(batch_and_board),
        "input_global": dict(batch_only),
        "output_policy": {0: "batch_size", 1: "board_area + 1"},
        "output_value": dict(batch_only),
        "output_miscvalue": dict(batch_only),
        "output_ownership": dict(batch_and_board),
    }

    torch.onnx.export(
        net,
        (binary_example, global_example),
        args.output,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=10,
        export_params=True,
        do_constant_folding=True,
        verbose=True,
    )
    print(f"ONNX model saved in {args.output}")
if __name__ == "__main__":
description = """
Convert KataGo .bin model to .onnx file.
"""
parser = ArgumentParser(description)
parser.add_argument(
"--model", type=str, required=True, help="KataGo .bin network file location"
)
parser.add_argument(
"--model-config",
type=str,
required=True,
help="KataGo model.config.json file location (usually archived in the .zip file)",
)
parser.add_argument(
"--output", type=str, default=None, help="Output .onnx network file location"
)
args = parser.parse_args()
if args.output is None:
args.output = args.model.replace(".bin", ".onnx")
main(args)