From 0bbc551975a1f4337594d9bab351442d937fc639 Mon Sep 17 00:00:00 2001 From: Jonathan Mackenzie Date: Wed, 10 Feb 2021 15:07:05 +1030 Subject: [PATCH] Added binary support --- tests/bin_test.thrift | 15 ++ tests/test_all_protocols_binary_field.py | 171 ++++++++++++++++++++++- thriftpy2/protocol/apache_json.py | 23 ++- thriftpy2/protocol/compact.py | 5 + thriftpy2/protocol/cybin/cybin.pyx | 8 +- thriftpy2/protocol/json.py | 5 + 6 files changed, 213 insertions(+), 14 deletions(-) create mode 100644 tests/bin_test.thrift diff --git a/tests/bin_test.thrift b/tests/bin_test.thrift new file mode 100644 index 00000000..be5df50d --- /dev/null +++ b/tests/bin_test.thrift @@ -0,0 +1,15 @@ +struct BinTest { + 1: binary tbinary, + 2: map str2bin, + 3: map bin2bin, + 4: map bin2str, + 5: list binlist, + 6: set binset, + 7: map> map_of_str2binlist, + 8: map> map_of_bin2bin, + 9: optional list> list_of_bin2str +} +service BinService { + // Testing Service that just returns what you give it + BinTest test(1: BinTest test); +} diff --git a/tests/test_all_protocols_binary_field.py b/tests/test_all_protocols_binary_field.py index d09b44fe..14e50680 100644 --- a/tests/test_all_protocols_binary_field.py +++ b/tests/test_all_protocols_binary_field.py @@ -1,23 +1,27 @@ from __future__ import absolute_import import time +import traceback from multiprocessing import Process import pytest import six +from thriftpy2.thrift import TType, TPayloadMeta +from thriftpy2.protocol import cybin import thriftpy2 -from thriftpy2.http import make_server as make_http_server, \ - make_client as make_http_client +from thriftpy2.http import ( + make_server as make_http_server, + make_client as make_http_client, +) from thriftpy2.protocol import ( TApacheJSONProtocolFactory, TJSONProtocolFactory, - TCompactProtocolFactory + TCompactProtocolFactory, ) from thriftpy2.protocol import TBinaryProtocolFactory -from thriftpy2.rpc import make_server as make_rpc_server, \ - make_client as make_rpc_client -from thriftpy2.transport import TBufferedTransportFactory +from thriftpy2.rpc import make_server as make_rpc_server, make_client as make_rpc_client +from thriftpy2.transport import TBufferedTransportFactory, TCyMemoryBuffer protocols = [TApacheJSONProtocolFactory, TJSONProtocolFactory, @@ -130,6 +134,7 @@ def run_server(): res = client.test(test_object) assert recursive_vars(res) == recursive_vars(test_object) except Exception as e: + traceback.print_exc() err = e finally: proc.terminate() @@ -179,3 +184,157 @@ def do_server(): proc.terminate() time.sleep(1) + + +@pytest.mark.parametrize('proto_factory', protocols) +def test_complex_binary(proto_factory): + + spec = thriftpy2.load("bin_test.thrift", module_name="bin_thrift") + bin_test_obj = spec.BinTest( + tbinary=b'\x01\x0f\xffa binary string\x0f\xee', + str2bin={ + 'key': 'value', + 'foo': 'bar' + }, + bin2bin={ + b'bin_key': b'bin_val', + 'str2bytes': b'bin bar' + }, + bin2str={ + b'bin key': 'str val', + }, + binlist=[b'bin one', b'bin two', 'str should become bin'], + binset={b'val 1', b'foo', b'bar', b'baz'}, + map_of_str2binlist={ + 'key1': [b'bin 1', b'pop 2'] + }, + map_of_bin2bin={ + b'abc': { + b'def': b'val', + b'\x1a\x04': b'\x45' + } + }, + list_of_bin2str=[ + { + b'bin key': 'str val', + b'other key\x04': 'bob' + } + ] + ) + + class Handler(object): + @staticmethod + def test(t): + return t + + trans_factory = TBufferedTransportFactory + + def run_server(): + server = make_rpc_server( + spec.BinService, + handler=Handler(), + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + server.serve() + + proc = Process(target=run_server) + proc.start() + time.sleep(0.2) + + try: + client = make_rpc_client( + spec.BinService, + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + res = client.test(bin_test_obj) + check_types(spec.BinTest.thrift_spec, res) + finally: + proc.terminate() + time.sleep(0.2) + + +def test_complex_map(): + """ + Test from #156 + """ + proto = cybin + b1 = TCyMemoryBuffer() + proto.write_val(b1, TType.MAP, {"hello": "1"}, + spec=(TType.STRING, TType.STRING)) + b1.flush() + + b2 = TCyMemoryBuffer() + proto.write_val(b2, TType.MAP, {"hello": b"1"}, + spec=(TType.STRING, TType.BINARY)) + b2.flush() + + assert b1.getvalue() != b2.getvalue() + + +type_map = { + TType.BYTE: (int,), + TType.I16: (int,), + TType.I32: (int,), + TType.I64: (int,), + TType.DOUBLE: (float,), + TType.STRING: six.string_types, + TType.BOOL: (bool,), + TType.STRUCT: TPayloadMeta, + TType.SET: (set, list), + TType.LIST: (list,), + TType.MAP: (dict,), + TType.BINARY: six.binary_type +} + +type_names = { + TType.BYTE: "Byte", + TType.I16: "I16", + TType.I32: "I32", + TType.I64: "I64", + TType.DOUBLE: "Double", + TType.STRING: "String", + TType.BOOL: "Bool", + TType.STRUCT: "Struct", + TType.SET: "Set", + TType.LIST: "List", + TType.MAP: "Map", + TType.BINARY: "Binary" +} + + +def check_types(spec, val): + """ + This function should check if a given thrift object matches + a thrift spec + Nb. This function isn't complete + + """ + if isinstance(spec, int): + assert isinstance(val, type_map.get(spec)) + elif isinstance(spec, tuple): + if len(spec) >= 2: + if spec[0] in (TType.LIST, TType.SET): + for item in val: + check_types(spec[1], item) + else: + for i in spec.values(): + t, field_name, to_type = i[:3] + value = getattr(val, field_name) + assert isinstance(value, type_map.get(t)), \ + "Field {} expected {} got {}".format(field_name, type_names.get(t), type(value)) + if to_type: + if t in (TType.SET, TType.LIST): + for _val in value: + check_types(to_type, _val) + elif t == TType.MAP: + for _key, _val in value.items(): + check_types(to_type[0], _key) + check_types(to_type[1], _val) + elif t == TType.STRUCT: + check_types(to_type, value) diff --git a/thriftpy2/protocol/apache_json.py b/thriftpy2/protocol/apache_json.py index 2f6816cb..185da857 100644 --- a/thriftpy2/protocol/apache_json.py +++ b/thriftpy2/protocol/apache_json.py @@ -8,8 +8,9 @@ from __future__ import absolute_import import json import base64 +import sys -from six import string_types +import six from thriftpy2.protocol import TProtocolBase from thriftpy2.thrift import TType @@ -56,6 +57,16 @@ def flatten(suitable_for_isinstance): return tuple(types) +def _ensure_b64_encode(val): + """ + Ensure that the variable is something that we can encode with b64encode + python3 needs bytes, python2 needs string + """ + if sys.version_info[0] > 2 and isinstance(val, str): + return val.encode() + return val + + class TApacheJSONProtocolFactory(object): @staticmethod def get_protocol(trans): @@ -165,15 +176,15 @@ def _thrift_to_dict(self, thrift_obj, item_type=None): self._thrift_to_dict(k, key_type): self._thrift_to_dict(v, to_type[1]) for k, v in thrift_obj.items() }] - if (to_type == TType.BINARY or item_type[0] == TType.BINARY) and TType.BINARY != TType.STRING: - return base64.b64encode(thrift_obj).decode('ascii') + if (to_type == TType.BINARY or item_type[-1] == TType.BINARY) and TType.BINARY != TType.STRING: + return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode('ascii') if isinstance(thrift_obj, bool): return int(thrift_obj) if ( item_type == TType.BINARY or (isinstance(item_type, tuple) and item_type[0] == TType.BINARY) ) and TType.BINARY != TType.STRING: - return base64.b64encode(thrift_obj).decode("ascii") + return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode("ascii") return thrift_obj result = {} for field_idx, thrift_spec in thrift_obj.thrift_spec.items(): @@ -202,7 +213,7 @@ def _thrift_to_dict(self, thrift_obj, item_type=None): } elif ttype == TType.BINARY and TType.BINARY != TType.STRING: result[field_idx] = { - CTYPES[ttype]: base64.b64encode(val).decode('ascii') + CTYPES[ttype]: base64.b64encode(_ensure_b64_encode(val)).decode('ascii') } elif ttype == TType.BOOL: result[field_idx] = { @@ -223,7 +234,7 @@ def _dict_to_thrift(self, data, base_type): :return: """ # if the result is a python type, return it: - if isinstance(data, (str, int, float, bool, bytes, string_types)) or data is None: + if isinstance(data, (str, int, float, bool, six.string_types, six.binary_type)) or data is None: if base_type in (TType.I08, TType.I16, TType.I32, TType.I64): return int(data) if base_type == TType.BINARY and TType.BINARY != TType.STRING: diff --git a/thriftpy2/protocol/compact.py b/thriftpy2/protocol/compact.py index dfe92e5b..7c567b3f 100644 --- a/thriftpy2/protocol/compact.py +++ b/thriftpy2/protocol/compact.py @@ -3,8 +3,11 @@ from __future__ import absolute_import import array +import sys from struct import pack, unpack +import six + from .exc import TProtocolException from .base import TProtocolBase from ..thrift import TException @@ -438,6 +441,8 @@ def _write_double(self, dub): def _write_binary(self, b): self._write_size(len(b)) + if isinstance(b, six.string_types) and sys.version_info[0] > 2: + b = b.encode() self.trans.write(b) def _write_string(self, s): diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx index 946c2d48..bb27296c 100644 --- a/thriftpy2/protocol/cybin/cybin.pyx +++ b/thriftpy2/protocol/cybin/cybin.pyx @@ -2,6 +2,8 @@ from libc.stdlib cimport free, malloc from libc.stdint cimport int16_t, int32_t, int64_t from cpython cimport bool +import six + from thriftpy2.transport.cybase cimport CyTransportBase, STACK_STRING_LEN from ..thrift import TDecodeException @@ -215,7 +217,7 @@ cdef inline write_struct(CyTransportBase buf, obj): write_i16(buf, fid) try: c_write_val(buf, f_type, v, container_spec) - except (TypeError, AttributeError, AssertionError, OverflowError): + except (TypeError, AttributeError, AssertionError, OverflowError) as e: raise TDecodeException(obj.__class__.__name__, fid, f_name, v, f_type, container_spec) @@ -357,10 +359,12 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None): write_double(buf, val) elif ttype == T_BINARY: + if isinstance(val, six.string_types): + val = val.encode() write_string(buf, val) elif ttype == T_STRING: - if not isinstance(val, bytes): + if not isinstance(val, six.binary_type): try: val = val.encode("utf-8") except Exception: diff --git a/thriftpy2/protocol/json.py b/thriftpy2/protocol/json.py index fe056add..89a57bd3 100644 --- a/thriftpy2/protocol/json.py +++ b/thriftpy2/protocol/json.py @@ -5,8 +5,11 @@ import json import struct import base64 +import sys from warnings import warn +import six + from thriftpy2._compat import u from thriftpy2.thrift import TType @@ -17,6 +20,8 @@ def encode_binary(data): + if isinstance(data, six.string_types) and sys.version_info[0] > 2: + data = data.encode() return base64.b64encode(data).decode('ascii')