-
Notifications
You must be signed in to change notification settings - Fork 0
/
AI出牌方案输出
38 lines (37 loc) · 1.39 KB
/
AI出牌方案输出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Display names for the classes: "等待" (wait), "出牌" (play), "不出" (pass).
# NOTE(review): presumably indexed by the class index the model predicts —
# confirm the order against the training label order.
labels = ['等待', '出牌', '不出']

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-50 created without pretrained weights; the actual weights are
# restored from the checkpoint below.
model = models.resnet50(pretrained=False)

# Replace the final fully connected layer with a small classification head
# that emits log-probabilities over config.NUM_CLASSES classes.
fc_inputs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, config.NUM_CLASSES),
    nn.LogSoftmax(dim=1),
)

pthfile = config.TRAINED_BEST_MODEL
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written from a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(pthfile, map_location=device)
model.load_state_dict(checkpoint['model'])
# The optimizer state is intentionally not restored here:
# optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
# test(model, test_load)

# Move the model to the target device and switch to evaluation mode
# (disables Dropout) for inference.
model.to(device).eval()
def detect_pass(pos):
    """Screenshot a screen region and classify it with the module-level model.

    The screenshot is written to ``datas/states`` under a timestamped file
    name, read back with OpenCV, resized to 224x224, scaled to [0, 1],
    normalised per channel and passed through ``model``. The predicted class
    index is printed and returned.

    Args:
        pos: Region forwarded unchanged to ``pyautogui.screenshot(region=...)``.

    Returns:
        int: Index of the highest-scoring class in the model output.
    """
    # Local import: the file's top-level import block is not visible from here.
    import os

    img = pyautogui.screenshot(region=pos)

    # Build the output path with os.path.join instead of a backslash literal:
    # "\s" is an invalid escape sequence and the separator was Windows-only.
    stamp = datetime.datetime.now().strftime(TIMESTAMP)
    out_dir = os.path.join("datas", "states")
    os.makedirs(out_dir, exist_ok=True)  # img.save() does not create directories
    path = os.path.join(out_dir, "state_" + stamp + ".png")
    img.save(path)

    # NOTE(review): cv2.imread yields BGR channel order, while the mean/std
    # values below are conventionally listed in RGB order — confirm this
    # matches the preprocessing used when the model was trained.
    src = cv2.imread(path)
    image = cv2.resize(src, (224, 224))
    image = np.float32(image) / 255.0
    # Per-channel normalisation, broadcast over the last (channel) axis.
    image -= np.array((0.485, 0.456, 0.406), dtype=np.float32)
    image /= np.array((0.229, 0.224, 0.225), dtype=np.float32)
    # HWC -> CHW, then add a leading batch dimension of size 1.
    image = image.transpose((2, 0, 1))
    input_x = torch.from_numpy(image).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(input_x)

    # .item() extracts the single index as a Python scalar; int() on a
    # 1-element ndarray is deprecated in recent NumPy releases.
    pred_index = int(torch.argmax(pred, 1).item())
    print(pred_index)
    return pred_index