diff --git a/thrift_connector/connection_pool.py b/thrift_connector/connection_pool.py index 716768f..d0fb866 100644 --- a/thrift_connector/connection_pool.py +++ b/thrift_connector/connection_pool.py @@ -6,9 +6,12 @@ import threading import time import socket +import urllib from collections import deque +from thrift.transport.THttpClient import THttpClient + from .hooks import api_call_context, client_get_hook @@ -188,6 +191,66 @@ def set_timeout(cls, socket, timeout): def get_timeout(self): return self.socket._timeout +class ThriftHttpClient(ThriftBaseClient): + @property + def TTransportException(self): + from thrift.transport.TTransport import TTransportException + return TTransportException + + @classmethod + def connect(cls, service, host, port, timeout=30, keepalive=None, + pool_generation=0, tracking=False, tracker_factory=None, + pool=None, use_limit=None, headers=None): + SOCKET = cls.get_socket_factory()(host, port, headers=headers) + cls.set_timeout(SOCKET, timeout * 1000) + PROTO_FACTORY = cls.get_protoco_factory() + TRANS_FACTORY = cls.get_transport_factory() + + transport = TRANS_FACTORY(SOCKET) + protocol = PROTO_FACTORY(transport) + + transport.open() + + return cls( + host=host, + port=port, + transport=transport, + protocol=protocol, + service=service, + keepalive=keepalive, + pool_generation=pool_generation, + tracking=tracking, + tracker_factory=tracker_factory, + pool=pool, + socket=SOCKET, + use_limit=use_limit, + ) + + @classmethod + def get_socket_factory(self): + return THttpEntity + + @classmethod + def get_transport_factory(self): + return THttpClientTransportAdapter + + @classmethod + def get_protoco_factory(self): + from thrift.protocol import TBinaryProtocol + return TBinaryProtocol.TBinaryProtocolAccelerated + + @classmethod + def set_timeout(cls, socket, timeout): + socket.setTimeout(timeout) + + def get_timeout(self): + return self.socket._timeout + + def get_tclient(self, service, protocol): + if self.tracking is True: + raise NotImplementedError( + "%s doesn't support tracking" % self.__class__.__name__) + return service.Client(protocol) class ThriftPyBaseClient(ThriftBaseClient): def init_client(self, client): @@ -458,6 +521,67 @@ def fill_connection_pool(self): def yield_server(self): return self.host, self.port +class ThriftHTTPClientPool(BaseClientPool): + def __init__(self, service, host, port=None, path=None, scheme='http', headers=None, timeout=300, name=None, + raise_empty=False, max_conn=30, connection_class=ThriftHttpClient, + keepalive=None, tracking=False, tracker_factory=None, + use_limit=None): + super(ThriftHTTPClientPool, self).__init__( + service=service, + timeout=timeout, + name=name, + raise_empty=raise_empty, + max_conn=max_conn, + connection_class=connection_class, + keepalive=keepalive, + tracking=tracking, + tracker_factory=tracker_factory, + use_limit=use_limit, + ) + if port is not None: + self.host = str(host).replace('http://','').replace('https://','') + self.port = port + self.path = path + self.scheme = scheme + else: + parsed = urllib.parse.urlparse(host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or '80' + elif self.scheme == 'https': + self.port = parsed.port or '443' + self.host = parsed.hostname + self.path = parsed.path + if parsed.query: + self.path += '?%s' % parsed.query + + self.headers = headers + + def produce_client(self, host=None, port=None, headers=None): + if host is None and port is None: + if headers is None: + host, port, headers = self.yield_server() + else: + host, port, aaa = self.yield_server() + elif not all((host, port)): + raise ValueError("host and port should be 'both none' \ + or 'both provided' ") + return self.connection_class.connect( + self.service, + host, + port, + self.timeout, + keepalive=self.keepalive, + pool_generation=self.generation, + tracking=self.tracking, + tracker_factory=self.tracker_factory, + pool=self, + use_limit=self.use_limit, headers=headers + ) + + def yield_server(self): + return self.host, self.port, self.headers class HeartbeatClientPool(ClientPool): @@ -553,3 +677,20 @@ def yield_server(self): ret = self.servers[self.index] self.index += 1 return ret + +class THttpEntity(object): + def __init__(self, uri_or_host, port=None, path=None, scheme='http', headers=None): + self.host = uri_or_host + self.port = port + self.path = path + self.scheme = scheme + self.headers = headers + + def setTimeout(self, timeout): + self._timeout = timeout + +class THttpClientTransportAdapter(THttpClient): + def __init__(self, entity): + super(THttpClientTransportAdapter, self).__init__(entity.scheme + "://" + entity.host + ":" + str(entity.port) + ("" if entity.path == None or entity.path == "" else entity.path)) + if entity.headers and isinstance(entity.headers, dict): + self.setCustomHeaders(entity.headers) \ No newline at end of file