-
Notifications
You must be signed in to change notification settings - Fork 2
/
client-no-wait.py
68 lines (57 loc) · 2.21 KB
/
client-no-wait.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import logging, grpc, time
import numpy as np
import server_tools_pb2
import server_tools_pb2_grpc
PORT = '50051'
# Server host name / address, read from IP.txt next to this script.
# Surrounding whitespace is stripped so that a trailing "\n" or a Windows
# "\r\n" line ending does not end up in the gRPC target string, and so an
# empty file yields "" instead of raising IndexError on IP[-1].
with open("IP.txt") as f:
    IP = f.read().strip()
def _load_image_data(path='image.bmp'):
    """Read a 28x28, 24-bit BMP file and return it serialized for the server.

    The first 54 bytes (the BMP header) are skipped and the rest is read as
    3 bytes per pixel. Each pixel's three channel bytes are averaged and
    inverted, so a black pixel becomes 1.0 and a white pixel 0.0. Rows are
    flipped vertically (output row 27 - y comes from file row y).
    NOTE(review): assumes a plain 54-byte header with no palette/extra
    header fields and bottom-up row order; 28 * 3 = 84 bytes per row, so
    no BMP row padding is expected.

    Args:
        path: Path of the bitmap to load.

    Returns:
        The float64 array of shape (1, 28, 28, 1) as raw C-order bytes.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If the pixel data is not exactly 28 * 28 * 3 bytes.
    """
    with open(path, 'rb') as image:
        pixel_bytes = image.read()[54:]  # Remove header
    # Make sure the image has no transparency and is 28x28
    if len(pixel_bytes) != 28 * 28 * 3:
        raise ValueError(
            'expected %d bytes of 24-bit pixel data for a 28x28 image, got %d'
            % (28 * 28 * 3, len(pixel_bytes)))
    pixels = np.frombuffer(pixel_bytes, dtype=np.uint8).reshape(28, 28, 3)
    # sum() promotes uint8 to a wide unsigned type, so no overflow; the
    # divisions then produce float64 exactly as the per-pixel formula did.
    gray = 1 - pixels.sum(axis=2) / 3 / 255
    return gray[::-1].reshape(1, 28, 28, 1).tobytes()


def run():
    """Submit image.bmp to the MNIST server once and print the result.

    Requests a client ID, sends the image with StartJobNoWait and polls
    ProbeJob until the job is complete or the server reports an error.
    Prints the predicted digit, the total round-trip time, the
    server-reported inference time and the percentage of time not spent
    predicting. On a connection failure, an unreadable/invalid image or a
    server-side error it prints a message and returns early, so the caller's
    prompt loop can try again. The channel is always closed before returning.
    """
    # Get a handle to the server
    channel = grpc.insecure_channel(IP + ':' + PORT)
    try:
        stub = server_tools_pb2_grpc.MnistServerStub(channel)
        # Get a client ID which you need to talk to the server
        try:
            response = stub.RequestClientID(server_tools_pb2.NullParam())
        except grpc.RpcError:
            print("Connection to the server could not be established. Press enter to try again.")
            return
        client_id = response.new_id
        # Load the image from image.bmp
        try:
            data = _load_image_data('image.bmp')
        except (OSError, ValueError) as err:
            print("Could not load image.bmp:", err)
            return
        # Pass the data to the server
        print("Submitting image")
        start_time = time.time()
        id_package = stub.StartJobNoWait(server_tools_pb2.DataMessage(
            images=data, client_id=client_id, batch_size=32))
        response = stub.ProbeJob(id_package)
        # Wait for the server to finish prediction.
        # NOTE(review): this polls back-to-back with no sleep, presumably on
        # purpose so the timing below is not inflated by a poll interval.
        print("Checking in with server")
        while not response.complete and response.error == '':
            response = stub.ProbeJob(id_package)
        if response.error != '':
            # No usable prediction payload; don't try to parse it.
            print("Server reported an error:", response.error)
            return
        # Find the most likely prediction and print it
        prediction = np.frombuffer(response.prediction).reshape(1, 10)
        whole_time = time.time() - start_time
        print("Prediction is:", int(np.argmax(prediction[0])))
        print("Total time:", whole_time)
        print("Predict time:", response.infer_time)
        print("Fraction of time spent not predicting:", (1 - response.infer_time / whole_time) * 100, '%')
    finally:
        channel.close()
if __name__ == '__main__':
    logging.basicConfig()
    # Repeat so that you can change the image between runs: an empty line
    # (just Enter) runs another prediction, any other input quits.
    try:
        while input('\nChange image.bmp if you like') == '':
            run()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed (e.g. Ctrl-D or exhausted piped input) or the
        # user pressed Ctrl-C: end the session without a traceback.
        print()