Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support THttpClientTransport #58

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions thrift_connector/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import threading
import time
import socket
import urllib

from collections import deque

from thrift.transport.THttpClient import THttpClient

from .hooks import api_call_context, client_get_hook


Expand Down Expand Up @@ -188,6 +191,66 @@ def set_timeout(cls, socket, timeout):
def get_timeout(self):
return self.socket._timeout

class ThriftHttpClient(ThriftBaseClient):
@property
def TTransportException(self):
from thrift.transport.TTransport import TTransportException
return TTransportException

@classmethod
def connect(cls, service, host, port, timeout=30, keepalive=None,
pool_generation=0, tracking=False, tracker_factory=None,
pool=None, use_limit=None, headers=None):
SOCKET = cls.get_socket_factory()(host, port, headers=headers)
cls.set_timeout(SOCKET, timeout * 1000)
PROTO_FACTORY = cls.get_protoco_factory()
TRANS_FACTORY = cls.get_transport_factory()

transport = TRANS_FACTORY(SOCKET)
protocol = PROTO_FACTORY(transport)

transport.open()

return cls(
host=host,
port=port,
transport=transport,
protocol=protocol,
service=service,
keepalive=keepalive,
pool_generation=pool_generation,
tracking=tracking,
tracker_factory=tracker_factory,
pool=pool,
socket=SOCKET,
use_limit=use_limit,
)

@classmethod
def get_socket_factory(self):
return THttpEntity

@classmethod
def get_transport_factory(self):
return THttpClientTransportAdapter

@classmethod
def get_protoco_factory(self):
from thrift.protocol import TBinaryProtocol
return TBinaryProtocol.TBinaryProtocolAccelerated

@classmethod
def set_timeout(cls, socket, timeout):
socket.setTimeout(timeout)

def get_timeout(self):
return self.socket._timeout

def get_tclient(self, service, protocol):
if self.tracking is True:
raise NotImplementedError(
"%s doesn't support tracking" % self.__class__.__name__)
return service.Client(protocol)

class ThriftPyBaseClient(ThriftBaseClient):
def init_client(self, client):
Expand Down Expand Up @@ -458,6 +521,67 @@ def fill_connection_pool(self):
def yield_server(self):
return self.host, self.port

class ThriftHTTPClientPool(BaseClientPool):
def __init__(self, service, host, port=None, path=None, scheme='http', headers=None, timeout=300, name=None,
raise_empty=False, max_conn=30, connection_class=ThriftHttpClient,
keepalive=None, tracking=False, tracker_factory=None,
use_limit=None):
super(ThriftHTTPClientPool, self).__init__(
service=service,
timeout=timeout,
name=name,
raise_empty=raise_empty,
max_conn=max_conn,
connection_class=connection_class,
keepalive=keepalive,
tracking=tracking,
tracker_factory=tracker_factory,
use_limit=use_limit,
)
if port is not None:
self.host = str(host).replace('http://','').replace('https://','')
self.port = port
self.path = path
self.scheme = scheme
else:
parsed = urllib.parse.urlparse(host)
self.scheme = parsed.scheme
assert self.scheme in ('http', 'https')
if self.scheme == 'http':
self.port = parsed.port or '80'
elif self.scheme == 'https':
self.port = parsed.port or '443'
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
self.path += '?%s' % parsed.query

self.headers = headers

def produce_client(self, host=None, port=None, headers=None):
if host is None and port is None:
if headers is None:
host, port, headers = self.yield_server()
else:
host, port, aaa = self.yield_server()
elif not all((host, port)):
raise ValueError("host and port should be 'both none' \
or 'both provided' ")
return self.connection_class.connect(
self.service,
host,
port,
self.timeout,
keepalive=self.keepalive,
pool_generation=self.generation,
tracking=self.tracking,
tracker_factory=self.tracker_factory,
pool=self,
use_limit=self.use_limit, headers=headers
)

def yield_server(self):
return self.host, self.port, self.headers

class HeartbeatClientPool(ClientPool):

Expand Down Expand Up @@ -553,3 +677,20 @@ def yield_server(self):
ret = self.servers[self.index]
self.index += 1
return ret

class THttpEntity(object):
def __init__(self, uri_or_host, port=None, path=None, scheme='http', headers=None):
self.host = uri_or_host
self.port = port
self.path = path
self.scheme = scheme
self.headers = headers

def setTimeout(self, timeout):
self._timeout = timeout

class THttpClientTransportAdapter(THttpClient):
def __init__(self, entity):
super(THttpClientTransportAdapter, self).__init__(entity.scheme + "://" + entity.host + ":" + str(entity.port) + ("" if entity.path == None or entity.path == "" else entity.path))
if entity.headers and isinstance(entity.headers, dict):
self.setCustomHeaders(entity.headers)