diff --git a/fakenet/listeners/ProxyListener.py b/fakenet/listeners/ProxyListener.py index 33b193dd..6f0e5ddc 100644 --- a/fakenet/listeners/ProxyListener.py +++ b/fakenet/listeners/ProxyListener.py @@ -9,6 +9,7 @@ import select import logging import ssl +import traceback from OpenSSL import SSL from ssl_utils import ssl_detector from . import * @@ -21,10 +22,10 @@ class ProxyListener(object): def __init__( - self, - config={}, - name ='ProxyListener', - logging_level=logging.DEBUG, + self, + config={}, + name ='ProxyListener', + logging_level=logging.DEBUG, ): self.logger = logging.getLogger(name) @@ -50,14 +51,14 @@ def start(self): self.logger.debug('Starting TCP ...') - self.server = ThreadedTCPServer((IP, + self.server = ThreadedTCPServer((IP, int(self.config.get('port'))), ThreadedTCPRequestHandler) - + elif proto == 'UDP': self.logger.debug('Starting UDP ...') - self.server = ThreadedUDPServer((IP, + self.server = ThreadedUDPServer((IP, int(self.config.get('port'))), ThreadedUDPRequestHandler) self.server.fwd_table = self.udp_fwd_table @@ -68,7 +69,7 @@ def start(self): else: self.logger.error('Protocol is not defined') return - + self.server.config = self.config self.server.logger = self.logger self.server.running_listeners = None @@ -78,7 +79,7 @@ def start(self): self.server_thread.daemon = True self.server_thread.start() server_ip, server_port = self.server.server_address - self.logger.info("%s Server(%s:%d) thread: %s" % (proto, server_ip, + self.logger.info("%s Server(%s:%d) thread: %s" % (proto, server_ip, server_port, self.server_thread.name)) def stop(self): @@ -92,40 +93,6 @@ def acceptListeners(self, listeners): def acceptDiverter(self, diverter): self.server.diverter = diverter - -class ThreadedTCPClientSocket(threading.Thread): - - - def __init__(self, ip, port, listener_q, remote_q, config, log): - - super(ThreadedTCPClientSocket, self).__init__() - self.ip = ip - self.port = int(port) - self.listener_q = listener_q - self.remote_q = remote_q - self.config = config - self.logger = log - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - def run(self): - - try: - self.sock.connect((self.ip, self.port)) - while True: - readable, writable, exceptional = select.select([self.sock], - [], [], .001) - if not self.remote_q.empty(): - data = self.remote_q.get() - self.sock.send(data) - if readable: - data = self.sock.recv(BUF_SZ) - if data: - self.listener_q.put(data) - else: - self.sock.close() - exit(1) - except Exception as e: - self.logger.debug('Listener socket exception %s' % e.message) class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): daemon_threads = True @@ -133,16 +100,16 @@ class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): class ThreadedUDPServer(SocketServer.ThreadingMixIn, SocketServer.UDPServer): daemon_threads = True -def get_top_listener(config, data, listeners, diverter, orig_src_ip, +def get_top_listener(config, data, listeners, diverter, orig_src_ip, orig_src_port, proto): - + top_listener = None top_confidence = 0 dport = diverter.getOriginalDestPort(orig_src_ip, orig_src_port, proto) for listener in listeners: - + try: confidence = listener.taste(data, dport) if confidence > top_confidence: @@ -151,20 +118,23 @@ def get_top_listener(config, data, listeners, diverter, orig_src_ip, except: # Exception occurs if taste() is not implemented for this listener pass - + return top_listener class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): - + def handle(self): + self.timeout = 3 + self.select_timeout = 0.001 + self.is_running = True + remote_sock = self.request # queue for data received from the listener listener_q = Queue.Queue() # queue for data received from remote remote_q = Queue.Queue() - data = None ssl_remote_sock = None @@ -180,80 +150,119 @@ def handle(self): self.logger.error('Could not locate %s', certfile_path) sys.exit(1) - ssl_version = ssl.PROTOCOL_SSLv23 try: data = remote_sock.recv(BUF_SZ, socket.MSG_PEEK) + log_details(self.server.logger, data) + except Exception as e: + data = '' + self.server.logger.info('recv() error: %s' % e.message) - self.server.logger.info('Received %d bytes.', len(data)) - self.server.logger.debug('%s', '-'*80,) - for line in hexdump_table(data): - self.server.logger.debug(line) - self.server.logger.debug('%s', '-'*80,) + if not data: + return + + + if ssl_detector.looks_like_ssl(data): + self.server.logger.debug('SSL detected') + ssl_remote_sock = ssl.wrap_socket( + remote_sock, + server_side=True, + do_handshake_on_connect=True, + certfile=certfile_path, + ssl_version=ssl.PROTOCOL_SSLv23, + keyfile=keyfile_path ) + data = ssl_remote_sock.recv(BUF_SZ) + + top_listener = get_top_listener(self.server.config, data, + self.server.listeners, self.server.diverter, + self.client_address[0], self.client_address[1], 'TCP') + if not top_listener: + return + + self.server.logger.debug('Likely listener: %s' % top_listener.name) + remote_sock.setblocking(0) + + # ssl has no 'peek' option, so we need to process the first + # packet that is already consumed from the socket + if ssl_remote_sock: + ssl_remote_sock.setblocking(0) + remote_q.put(data) + + data_available = threading.Event() + + # Try to connect to listener socket + lsocket = self.connect_to_listener('localhost', int(top_listener.port)) + if lsocket is None: + self.server.logger.error('Failed to connect to listener socket') + return + + threading.Thread(target=self.receive_data, args=[ + lsocket, # listener socket + listener_q, # data queue + data_available, # event + ]).start() + + rsocket = remote_sock if ssl_remote_sock is None else ssl_remote_sock + threading.Thread(target=self.receive_data, args=[ + rsocket, # remote socket + remote_q, # queue + data_available # event + ]).start() + + self.proxy(rsocket, lsocket, remote_q, listener_q, data_available) + lsocket.close() + rsocket.close() + return + + def proxy(self, rsocket, lsocket, rq, lq, ev): + try: + while self.is_running: + while not rq.empty(): + data = rq.get() + lsocket.send(data) + while not lq.empty(): + data = lq.get() + rsocket.send(data) + ev.clear() + if not ev.wait(timeout=self.timeout): + self.is_running = False + break except Exception as e: - self.server.logger.info('recv() error: %s' % e.message) + self.server.logger.error('Failed to proxy data') + self.server.logger.debug(traceback.format_exc()) + return - if data: + def receive_data(self, s, q, ev): + try: + while self.is_running: + readable, writable, exceptional = select.select([s], [], [], self.select_timeout) + if readable: + data = s.recv(BUF_SZ) + if data: + q.put(data, block=True) + ev.set() + else: + s.close() + break + except Exception as e: + self.server.logger.error('Exception when trying to receive data') + self.server.logger.debug(e.message) - if ssl_detector.looks_like_ssl(data): - self.server.logger.debug('SSL detected') - ssl_remote_sock = ssl.wrap_socket( - remote_sock, - server_side=True, - do_handshake_on_connect=True, - certfile=certfile_path, - ssl_version=ssl_version, - keyfile=keyfile_path ) - data = ssl_remote_sock.recv(BUF_SZ) - - orig_src_ip = self.client_address[0] - orig_src_port = self.client_address[1] - - top_listener = get_top_listener(self.server.config, data, - self.server.listeners, self.server.diverter, - orig_src_ip, orig_src_port, 'TCP') + # Always set the last event to have the main loop wake up + ev.set() + return + + def connect_to_listener(self, host, port): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((host, port)) + except Exception as e: + s = None + self.server.logger.error('Failed to connect to listener') + self.server.logger.error(e.message) + return s - if top_listener: - self.server.logger.debug('Likely listener: %s' % - top_listener.name) - listener_sock = ThreadedTCPClientSocket('localhost', - top_listener.port, listener_q, remote_q, - self.server.config, self.server.logger) - listener_sock.daemon = True - listener_sock.start() - remote_sock.setblocking(0) - - # ssl has no 'peek' option, so we need to process the first - # packet that is already consumed from the socket - if ssl_remote_sock: - ssl_remote_sock.setblocking(0) - remote_q.put(data) - - while True: - readable, writable, exceptional = select.select( - [remote_sock], [], [], .001) - if readable: - try: - if ssl_remote_sock: - data = ssl_remote_sock.recv(BUF_SZ) - else: - data = remote_sock.recv(BUF_SZ) - if data: - remote_q.put(data) - else: - self.server.logger.debug( - 'Closing remote socket connection') - return - except Exception as e: - self.server.logger.debug('Remote Connection terminated') - return - if not listener_q.empty(): - data = listener_q.get() - if ssl_remote_sock: - ssl_remote_sock.send(data) - else: - remote_sock.send(data) class ThreadedUDPRequestHandler(SocketServer.BaseRequestHandler): @@ -262,22 +271,19 @@ def handle(self): data = self.request[0] remote_sock = self.request[1] - self.server.logger.debug('Received UDP packet from %s.' % + self.server.logger.debug('Received UDP packet from %s.' % self.client_address[0]) if data: self.server.logger.info('Received %d bytes.', len(data)) - self.server.logger.debug('%s', '-'*80,) - for line in hexdump_table(data): - self.server.logger.debug(line) - self.server.logger.debug('%s', '-'*80,) + log_details(self.server.logger, data) orig_src_ip = self.client_address[0] orig_src_port = self.client_address[1] - top_listener = get_top_listener(self.server.config, data, - self.server.listeners, self.server.diverter, + top_listener = get_top_listener(self.server.config, data, + self.server.listeners, self.server.diverter, orig_src_ip, orig_src_port, 'UDP') if top_listener: @@ -303,20 +309,30 @@ def hexdump_table(data, length=16): hexdump_lines.append("%04X: %-*s %s" % (i, length*3, hex_line, ascii_line )) return hexdump_lines + +def log_details(logger, data): + logger.info('Received %d bytes.', len(data)) + logger.debug('%s', '-'*80,) + for line in hexdump_table(data): + logger.debug(line) + logger.debug('%s', '-'*80,) + return + + def main(): - logging.basicConfig(format='%(asctime)s [%(name)15s] %(message)s', + logging.basicConfig(format='%(asctime)s [%(name)15s] %(message)s', datefmt='%m/%d/%y %I:%M:%S %p', level=logging.DEBUG) global listeners listeners = load_plugins() - TCP_server = ThreadedTCPServer((IP, int(sys.argv[1])), + TCP_server = ThreadedTCPServer((IP, int(sys.argv[1])), ThreadedTCPRequestHandler) TCP_server_thread = threading.Thread(target=TCP_server.serve_forever) TCP_server_thread.daemon = True TCP_server_thread.start() tcp_server_ip, tcp_server_port = TCP_server.server_address - logger.info("TCP Server(%s:%d) thread: %s" % (tcp_server_ip, + logger.info("TCP Server(%s:%d) thread: %s" % (tcp_server_ip, tcp_server_port, TCP_server_thread.name)) try: