Skip to content

Commit

Permalink
Added binary support
Browse files Browse the repository at this point in the history
  • Loading branch information
JonnoFTW committed Feb 10, 2021
1 parent d4bd284 commit 0bbc551
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 14 deletions.
15 changes: 15 additions & 0 deletions tests/bin_test.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
struct BinTest {
1: binary tbinary,
2: map<string, binary> str2bin,
3: map<binary, binary> bin2bin,
4: map<binary, string> bin2str,
5: list<binary> binlist,
6: set<binary> binset,
7: map<string, list<binary>> map_of_str2binlist,
8: map<binary, map<binary, binary>> map_of_bin2bin,
9: optional list<map<binary, string>> list_of_bin2str
}
service BinService {
// Testing Service that just returns what you give it
BinTest test(1: BinTest test);
}
171 changes: 165 additions & 6 deletions tests/test_all_protocols_binary_field.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from __future__ import absolute_import

import time
import traceback
from multiprocessing import Process

import pytest
import six

from thriftpy2.thrift import TType, TPayloadMeta
from thriftpy2.protocol import cybin
import thriftpy2
from thriftpy2.http import make_server as make_http_server, \
make_client as make_http_client
from thriftpy2.http import (
make_server as make_http_server,
make_client as make_http_client,
)
from thriftpy2.protocol import (
TApacheJSONProtocolFactory,
TJSONProtocolFactory,
TCompactProtocolFactory
TCompactProtocolFactory,
)
from thriftpy2.protocol import TBinaryProtocolFactory
from thriftpy2.rpc import make_server as make_rpc_server, \
make_client as make_rpc_client
from thriftpy2.transport import TBufferedTransportFactory
from thriftpy2.rpc import make_server as make_rpc_server, make_client as make_rpc_client
from thriftpy2.transport import TBufferedTransportFactory, TCyMemoryBuffer

protocols = [TApacheJSONProtocolFactory,
TJSONProtocolFactory,
Expand Down Expand Up @@ -130,6 +134,7 @@ def run_server():
res = client.test(test_object)
assert recursive_vars(res) == recursive_vars(test_object)
except Exception as e:
traceback.print_exc()
err = e
finally:
proc.terminate()
Expand Down Expand Up @@ -179,3 +184,157 @@ def do_server():

proc.terminate()
time.sleep(1)


@pytest.mark.parametrize('proto_factory', protocols)
def test_complex_binary(proto_factory):

spec = thriftpy2.load("bin_test.thrift", module_name="bin_thrift")
bin_test_obj = spec.BinTest(
tbinary=b'\x01\x0f\xffa binary string\x0f\xee',
str2bin={
'key': 'value',
'foo': 'bar'
},
bin2bin={
b'bin_key': b'bin_val',
'str2bytes': b'bin bar'
},
bin2str={
b'bin key': 'str val',
},
binlist=[b'bin one', b'bin two', 'str should become bin'],
binset={b'val 1', b'foo', b'bar', b'baz'},
map_of_str2binlist={
'key1': [b'bin 1', b'pop 2']
},
map_of_bin2bin={
b'abc': {
b'def': b'val',
b'\x1a\x04': b'\x45'
}
},
list_of_bin2str=[
{
b'bin key': 'str val',
b'other key\x04': 'bob'
}
]
)

class Handler(object):
@staticmethod
def test(t):
return t

trans_factory = TBufferedTransportFactory

def run_server():
server = make_rpc_server(
spec.BinService,
handler=Handler(),
host='localhost',
port=9090,
proto_factory=proto_factory(),
trans_factory=trans_factory(),
)
server.serve()

proc = Process(target=run_server)
proc.start()
time.sleep(0.2)

try:
client = make_rpc_client(
spec.BinService,
host='localhost',
port=9090,
proto_factory=proto_factory(),
trans_factory=trans_factory(),
)
res = client.test(bin_test_obj)
check_types(spec.BinTest.thrift_spec, res)
finally:
proc.terminate()
time.sleep(0.2)


def test_complex_map():
"""
Test from #156
"""
proto = cybin
b1 = TCyMemoryBuffer()
proto.write_val(b1, TType.MAP, {"hello": "1"},
spec=(TType.STRING, TType.STRING))
b1.flush()

b2 = TCyMemoryBuffer()
proto.write_val(b2, TType.MAP, {"hello": b"1"},
spec=(TType.STRING, TType.BINARY))
b2.flush()

assert b1.getvalue() != b2.getvalue()


type_map = {
TType.BYTE: (int,),
TType.I16: (int,),
TType.I32: (int,),
TType.I64: (int,),
TType.DOUBLE: (float,),
TType.STRING: six.string_types,
TType.BOOL: (bool,),
TType.STRUCT: TPayloadMeta,
TType.SET: (set, list),
TType.LIST: (list,),
TType.MAP: (dict,),
TType.BINARY: six.binary_type
}

type_names = {
TType.BYTE: "Byte",
TType.I16: "I16",
TType.I32: "I32",
TType.I64: "I64",
TType.DOUBLE: "Double",
TType.STRING: "String",
TType.BOOL: "Bool",
TType.STRUCT: "Struct",
TType.SET: "Set",
TType.LIST: "List",
TType.MAP: "Map",
TType.BINARY: "Binary"
}


def check_types(spec, val):
"""
This function should check if a given thrift object matches
a thrift spec
Nb. This function isn't complete
"""
if isinstance(spec, int):
assert isinstance(val, type_map.get(spec))
elif isinstance(spec, tuple):
if len(spec) >= 2:
if spec[0] in (TType.LIST, TType.SET):
for item in val:
check_types(spec[1], item)
else:
for i in spec.values():
t, field_name, to_type = i[:3]
value = getattr(val, field_name)
assert isinstance(value, type_map.get(t)), \
"Field {} expected {} got {}".format(field_name, type_names.get(t), type(value))
if to_type:
if t in (TType.SET, TType.LIST):
for _val in value:
check_types(to_type, _val)
elif t == TType.MAP:
for _key, _val in value.items():
check_types(to_type[0], _key)
check_types(to_type[1], _val)
elif t == TType.STRUCT:
check_types(to_type, value)
23 changes: 17 additions & 6 deletions thriftpy2/protocol/apache_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from __future__ import absolute_import
import json
import base64
import sys

from six import string_types
import six

from thriftpy2.protocol import TProtocolBase
from thriftpy2.thrift import TType
Expand Down Expand Up @@ -56,6 +57,16 @@ def flatten(suitable_for_isinstance):
return tuple(types)


def _ensure_b64_encode(val):
"""
Ensure that the variable is something that we can encode with b64encode
python3 needs bytes, python2 needs string
"""
if sys.version_info[0] > 2 and isinstance(val, str):
return val.encode()
return val


class TApacheJSONProtocolFactory(object):
@staticmethod
def get_protocol(trans):
Expand Down Expand Up @@ -165,15 +176,15 @@ def _thrift_to_dict(self, thrift_obj, item_type=None):
self._thrift_to_dict(k, key_type):
self._thrift_to_dict(v, to_type[1]) for k, v in thrift_obj.items()
}]
if (to_type == TType.BINARY or item_type[0] == TType.BINARY) and TType.BINARY != TType.STRING:
return base64.b64encode(thrift_obj).decode('ascii')
if (to_type == TType.BINARY or item_type[-1] == TType.BINARY) and TType.BINARY != TType.STRING:
return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode('ascii')
if isinstance(thrift_obj, bool):
return int(thrift_obj)
if (
item_type == TType.BINARY
or (isinstance(item_type, tuple) and item_type[0] == TType.BINARY)
) and TType.BINARY != TType.STRING:
return base64.b64encode(thrift_obj).decode("ascii")
return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode("ascii")
return thrift_obj
result = {}
for field_idx, thrift_spec in thrift_obj.thrift_spec.items():
Expand Down Expand Up @@ -202,7 +213,7 @@ def _thrift_to_dict(self, thrift_obj, item_type=None):
}
elif ttype == TType.BINARY and TType.BINARY != TType.STRING:
result[field_idx] = {
CTYPES[ttype]: base64.b64encode(val).decode('ascii')
CTYPES[ttype]: base64.b64encode(_ensure_b64_encode(val)).decode('ascii')
}
elif ttype == TType.BOOL:
result[field_idx] = {
Expand All @@ -223,7 +234,7 @@ def _dict_to_thrift(self, data, base_type):
:return:
"""
# if the result is a python type, return it:
if isinstance(data, (str, int, float, bool, bytes, string_types)) or data is None:
if isinstance(data, (str, int, float, bool, six.string_types, six.binary_type)) or data is None:
if base_type in (TType.I08, TType.I16, TType.I32, TType.I64):
return int(data)
if base_type == TType.BINARY and TType.BINARY != TType.STRING:
Expand Down
5 changes: 5 additions & 0 deletions thriftpy2/protocol/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from __future__ import absolute_import

import array
import sys
from struct import pack, unpack

import six

from .exc import TProtocolException
from .base import TProtocolBase
from ..thrift import TException
Expand Down Expand Up @@ -438,6 +441,8 @@ def _write_double(self, dub):

def _write_binary(self, b):
self._write_size(len(b))
if isinstance(b, six.string_types) and sys.version_info[0] > 2:
b = b.encode()
self.trans.write(b)

def _write_string(self, s):
Expand Down
8 changes: 6 additions & 2 deletions thriftpy2/protocol/cybin/cybin.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ from libc.stdlib cimport free, malloc
from libc.stdint cimport int16_t, int32_t, int64_t
from cpython cimport bool

import six

from thriftpy2.transport.cybase cimport CyTransportBase, STACK_STRING_LEN

from ..thrift import TDecodeException
Expand Down Expand Up @@ -215,7 +217,7 @@ cdef inline write_struct(CyTransportBase buf, obj):
write_i16(buf, fid)
try:
c_write_val(buf, f_type, v, container_spec)
except (TypeError, AttributeError, AssertionError, OverflowError):
except (TypeError, AttributeError, AssertionError, OverflowError) as e:
raise TDecodeException(obj.__class__.__name__, fid, f_name, v,
f_type, container_spec)

Expand Down Expand Up @@ -357,10 +359,12 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
write_double(buf, val)

elif ttype == T_BINARY:
if isinstance(val, six.string_types):
val = val.encode()
write_string(buf, val)

elif ttype == T_STRING:
if not isinstance(val, bytes):
if not isinstance(val, six.binary_type):
try:
val = val.encode("utf-8")
except Exception:
Expand Down
5 changes: 5 additions & 0 deletions thriftpy2/protocol/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import json
import struct
import base64
import sys
from warnings import warn

import six

from thriftpy2._compat import u
from thriftpy2.thrift import TType

Expand All @@ -17,6 +20,8 @@


def encode_binary(data):
if isinstance(data, six.string_types) and sys.version_info[0] > 2:
data = data.encode()
return base64.b64encode(data).decode('ascii')


Expand Down

0 comments on commit 0bbc551

Please sign in to comment.