# server_udt_manager.py
# Forked from avivkiss/warp (GitHub listing: 120 lines, 94 loc, 3.02 KB).
from config import *
from common_tools import *
import socket
import random
from udt4py import UDTSocket
import threading
class ServerUDTManager:
    """Server side of a single data connection, over UDT or plain TCP.

    On construction a socket is bound to an OS-chosen port and a random
    numeric nonce is generated.  ``open_connection`` starts a background
    thread that accepts one client and checks that the first bytes it sends
    equal the nonce; ``receive_data`` then copies incoming bytes into a file
    on another background thread.

    NOTE(review): ``NONCE_SIZE``, ``CHUNK_SIZE``, ``logger`` and ``fail`` come
    from the module's star imports (config / common_tools); ``fail`` is
    presumably fatal -- confirm.
    """

    def __init__(self, tcp_mode):
        """
        :param tcp_mode: if true, use a plain TCP socket; otherwise a UDT
            socket layered on top of a UDP socket.
        """
        self.tcp_mode = tcp_mode
        self.udt_sock = None   # listening UDTSocket (UDT mode only)
        self.conn = None       # accepted connection, set by accept_and_verify
        self.sock = self.get_socket()
        self.port = self.sock.getsockname()[1]
        self.nonce = self.generate_nonce()
        self.size = 0          # byte count maintained by the receiving thread

    def open_connection(self):
        """Start listening and accept the client on a background thread.

        :returns: ``(port, nonce)`` -- presumably forwarded to the client so
            it can connect and authenticate (the caller is not visible here).
        """
        if not self.tcp_mode:
            # UDT runs on top of the already-bound UDP socket.
            self.udt_sock = UDTSocket()
            self.udt_sock.bind(self.sock.fileno())
            self.udt_sock.listen()
        listening_thread = threading.Thread(target=self.accept_and_verify)
        listening_thread.start()
        return (self.port, self.nonce)

    def get_total_recieved(self):
        """Return the byte count tracked by the receiving thread.

        The count starts at the resume offset ``block_count * CHUNK_SIZE``
        when ``receive_data`` begins.  (The misspelled name is kept because
        callers depend on it.)
        """
        return self.size

    def accept_and_verify(self):
        """Accept one client and check the nonce it sends first.

        Calls ``fail`` when the received nonce differs from ``self.nonce``.
        """
        if not self.tcp_mode:
            self.conn, addr = self.udt_sock.accept()
            logger.info('Connected by %s', addr)
            buf = bytearray(NONCE_SIZE)
            # NOTE(review): one UDT recv may deliver fewer than NONCE_SIZE
            # bytes; left as a single call because whether udt4py's
            # recv(buffer) can fill at an offset is not visible here.
            self.conn.recv(buf)
            raw = bytes(buf)
        else:
            self.conn, addr = self.sock.accept()
            logger.info('Connected by %s', addr)
            # TCP is a byte stream: a single recv() may return a partial nonce.
            raw = self._recv_exactly(NONCE_SIZE)
        # Decode so the comparison with the str nonce behaves the same on
        # Python 2 and 3; undecodable bytes become U+FFFD and fail the check.
        recvd_nonce = raw.decode('ascii', 'replace')
        if recvd_nonce != self.nonce:
            # The builtin format() accepts at most two arguments, so build
            # the message with %-formatting.
            fail("Received nonce %s doesn't match %s." % (recvd_nonce, self.nonce))
        else:
            logger.debug("Nonce verified.")

    def _recv_exactly(self, count):
        """Read up to ``count`` bytes from the TCP connection.

        Loops because ``socket.recv`` may return fewer bytes than requested;
        stops early if the peer closes the connection.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.conn.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def receive_data(self, output_file, block_count, file_size):
        """Receive data on a background thread and write it into a file.

        The existing file at ``output_file`` is opened for update, and
        writing starts at offset ``block_count * CHUNK_SIZE`` so that a
        partial transfer can be resumed.  In UDT mode, receiving stops on a
        zero-length read or once the running byte count (which includes the
        offset) equals ``file_size``.  In TCP mode, it stops when the peer
        closes the connection.  The file is closed even if receiving raises.

        :param output_file: path of the file to write into (must already
            exist; opened with mode ``r+b``).
        :param block_count: number of CHUNK_SIZE blocks already on disk.
        :param file_size: expected total size in bytes (int or numeric str).
        :returns: the started ``threading.Thread``.
        """
        def receive_data_threaded(path, block_count, file_size):
            logger.debug("Receiving data...")
            offset = block_count * CHUNK_SIZE
            # Binary mode: the payload is raw bytes and must not undergo any
            # newline or encoding translation.
            out = open(path, "r+b")
            try:
                out.seek(offset)
                self.size = offset
                if not self.tcp_mode:
                    self._receive_udt(out, file_size)
                else:
                    self._receive_tcp(out)
            finally:
                logger.debug("Closing file... " + out.name)
                out.close()

        thread = threading.Thread(target=receive_data_threaded,
                                  args=(output_file, block_count, file_size))
        thread.start()
        return thread

    def _receive_udt(self, out, file_size):
        """Copy UDT data into ``out`` until EOF or ``file_size`` bytes total."""
        # Compare string forms so file_size may be an int or a numeric string.
        target = str(file_size)
        # One reusable receive buffer.  It must remain a bytearray because
        # UDTSocket.recv fills it in place and returns the byte count.
        buf = bytearray(CHUNK_SIZE)
        while True:
            len_rec = self.conn.recv(buf)
            out.write(bytes(buf[:len_rec]))
            self.size += len_rec
            if len_rec == 0 or str(self.size) == target:
                break

    def _receive_tcp(self, out):
        """Copy bytes from the TCP connection into ``out`` until it closes."""
        while True:
            chunk = self.conn.recv(CHUNK_SIZE)
            if not chunk:
                break
            out.write(chunk)
            self.size += len(chunk)

    def get_socket(self):
        """Open a socket bound to an OS-chosen port on all interfaces.

        In TCP mode this is a listening SOCK_STREAM socket (backlog 1).
        Otherwise it is a SOCK_DGRAM socket that ``open_connection`` hands
        to UDT.  Socket errors are reported through ``fail``.
        """
        sock_type = socket.SOCK_STREAM if self.tcp_mode else socket.SOCK_DGRAM
        s = None
        try:
            s = socket.socket(socket.AF_INET, sock_type)
        except socket.error as msg:
            fail(str(msg))
        try:
            s.bind(('', 0))  # port 0: let the OS pick a free port
            if self.tcp_mode:
                s.listen(1)
        except socket.error as msg:
            s.close()
            fail(str(msg))
        return s

    def generate_nonce(self, length=NONCE_SIZE):
        """Return a string of ``length`` random decimal digits.

        The nonce is what ``accept_and_verify`` checks the client against.
        It is therefore drawn from the OS CSPRNG (``random.SystemRandom``)
        rather than the predictable generator behind the module-level
        ``random`` functions.
        """
        rng = random.SystemRandom()
        return ''.join(str(rng.randint(0, 9)) for _ in range(length))