forked from LPD-EPFL/CS453-2023-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
submit.py
236 lines (211 loc) · 7.32 KB
/
submit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# coding: utf-8
###
# @file submit.py
# @author Sébastien Rouault <[email protected]>
#
# @section LICENSE
#
# Copyright © 2018-2019 École Polytechnique Fédérale de Lausanne (EPFL).
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# any later version. Please see https://gnu.org/licenses/gpl.html
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# @section DESCRIPTION
#
# Client for the automated performance measurement tool of CS-453.
###
if __name__ != "__main__":
raise RuntimeError("Script " + repr(__file__) + " is to be used as the main module only")
# ---------------------------------------------------------------------------- #
# Python version check and imports
import sys
if sys.version_info.major != 3 or sys.version_info.minor < 5:
print("WARNING: python interpreter not supported, please install version 3.5 or compatible (e.g. 3.6, 3.7, etc)")
import argparse
import atexit
import pathlib
import socket
# ---------------------------------------------------------------------------- #
# Configuration
version_uid = b"\x00\x00\x00\x02" # Unique version identifier (must be identical in compatible server)
default_host = "lpdxeon2680.epfl.ch" # Default server hostname or IPv4
default_port = 9997 # Default server TCP port
max_codesize = 100000 # Max code size (in bytes) before AND after deflate (the same as in the server); modifying this value won't change the behavior of the server ;)
# ---------------------------------------------------------------------------- #
# Command line
# Description
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--uuid",
type=str,
required=True,
help="Secret user unique identifier")
parser.add_argument("--host",
type=str,
default=default_host,
help="Server hostname")
parser.add_argument("--port",
type=int,
default=default_port,
help="Server TCP port")
parser.add_argument("--force-update",
action="store_true",
help="Update your submission to the server; no validation of the correctness of your submission will be performed")
parser.add_argument("--download",
action="store_true",
help="Download your submission currently kept for grading under 'zippath' (instead of submitting a new one); if '--force-update' is used, a \"swap\" is performed")
parser.add_argument("zippath",
type=str,
help="Path to a zip file containing your library code, or that will be overwritten with '--download'")
# Command line parsing
args = parser.parse_args(sys.argv[1:])
# ---------------------------------------------------------------------------- #
# Socket helper
def socket_encode_size(size):
""" Encode the given number into 4 bytes.
Args:
size Size to encode
Returns:
Encoded size
"""
# Assertions
if size < 0 or size >= (1 << 32):
raise OverflowError
# Encoding
res = bytearray(4)
for i in range(4):
res[i] = size & 0xFF
size >>= 8
return res
def socket_decode_size(size):
""" Decode the given 4-byte encoded size into an integer.
Args:
size Encoded size
Returns:
Decoded size
"""
# Assertions
if len(size) != 4:
raise OverflowError
# Decoding
res = 0
for i in range(4):
res <<= 8
res += size[3 - i]
return res
def socket_consume(fd, size):
""" Repeatedly read the socket until the given size has been received.
Args:
fd Socket file descriptor to read
size Size to read from the socket
"""
data = bytes()
while size > 0:
recv = fd.recv(size)
if len(recv) <= 0:
raise IOError("No more data in the socket")
data += recv
size -= len(recv)
return data
def socket_recvfield(fd, maxsize, exact=False):
""" Receive a field from a socket.
Args:
fd Socket file descriptor to read
maxsize Maximum/exact size to accept
exact Whether field size was exact
Returns:
Received field bytes
"""
size = socket_decode_size(socket_consume(fd, 4))
if size > maxsize:
raise IOError("Field is too large")
elif exact and size < maxsize:
raise IOError("Field is too small")
return socket_consume(fd, size)
def socket_sendfield(fd, data):
""" Send a field through a socket.
Args:
fd Socket file descriptor to write
data Data bytes to send
"""
if fd.sendall(socket_encode_size(len(data))) is not None or fd.sendall(data) is not None:
raise IOError("Send failed")
# ---------------------------------------------------------------------------- #
# Client
# Open connection with the server
client_fd = None
for af, socktype, proto, canonname, sa in socket.getaddrinfo(args.host, args.port, socket.AF_INET, socket.SOCK_STREAM):
try:
client_fd = socket.socket(af, socktype, proto)
except OSError as msg:
client_fd = None
continue
try:
client_fd.connect(sa)
except OSError as msg:
client_fd.close()
client_fd = None
continue
break
if client_fd is None:
print("""Unable to connect to %s:%s
Message to the students:
Please make sure you operate from EPFL's network (i.e. either on campus or through the VPN).
If you do, it is not unlikely that the server machine is temporarily being used for another purpose.
Please retry later; contact the TAs only if the problem persists for more than a day.""" % (args.host, args.port))
exit(1)
atexit.register(lambda: client_fd.close())
# Check version
if socket_recvfield(client_fd, len(version_uid)) != version_uid:
print("Protocol version mismatch with the server, please pull/download the latest version of the client.")
exit(1)
# Send the secret user identifier
socket_sendfield(client_fd, args.uuid.encode())
res = socket_recvfield(client_fd, 1, exact=True)
msg = {1: "Invalid user secret identifier", 2: "Unknown user secret identifier", 3: "User is already logged in"}
if res[0] in msg:
print(msg[res[0]]) # Unsuccessful identifications are logged
exit(1)
# Send mode of submission
submit_mode = (1 if args.download else 0) + (2 if args.force_update else 0)
socket_sendfield(client_fd, submit_mode.to_bytes(1, "little"))
# Process according to mode of submission
if args.force_update or not args.download:
# Send zip file
zip_path = pathlib.Path(args.zippath)
if not zip_path.exists():
print("File %r cannot be accessed" % str(zip_path))
exit(1)
socket_sendfield(client_fd, zip_path.read_bytes())
if args.download:
# Receive zip file
with pathlib.Path(args.zippath).open("wb") as fd:
fd.write(socket_recvfield(client_fd, max_codesize))
# Read until the socket is closed (for user feedback/error messages)
try:
prev = bytes()
while True:
data = client_fd.recv(256)
if len(data) <= 0:
# Here 'prev' should be empty or not enough data to decode 'prev' correctly: so do nothing
break
data = prev + data
prev = bytes()
try:
text = data.decode()
except UnicodeDecodeError as err:
prev = data[err.start:]
text = data[:err.start].decode()
if len(text) > 0:
sys.stdout.write(text)
sys.stdout.flush()
except ConnectionResetError:
pass
except KeyboardInterrupt:
pass