diff --git a/tests/container.thrift b/tests/container.thrift index 2ad4697..00e4e4c 100644 --- a/tests/container.thrift +++ b/tests/container.thrift @@ -16,3 +16,22 @@ struct MixItem { 1: optional list> list_map, 2: optional map> map_list, } + +struct BinListStruct { + 1: optional list list_items, +} + +struct BinListItem { + 1: optional list list_binary, + 2: optional list> list_list_binary, +} + +struct BinMapItem { + 1: optional map map_binary, + 2: optional map> map_map_binary, +} + +struct BinMixItem { + 1: optional list> list_map, + 2: optional map> map_list, +} diff --git a/tests/test_all_protocols_binary_field.py b/tests/test_all_protocols_binary_field.py index 03c604b..116b029 100644 --- a/tests/test_all_protocols_binary_field.py +++ b/tests/test_all_protocols_binary_field.py @@ -279,7 +279,7 @@ def test_complex_map(): spec=(TType.STRING, TType.BINARY)) b2.flush() - assert b1.getvalue() != b2.getvalue() + assert b1.getvalue() == b2.getvalue() type_map = { diff --git a/tests/test_protocol_binary.py b/tests/test_protocol_binary.py index d0f07c4..21a494e 100644 --- a/tests/test_protocol_binary.py +++ b/tests/test_protocol_binary.py @@ -6,6 +6,8 @@ from thriftpy2.thrift import TType, TPayload from thriftpy2.utils import hexlify from thriftpy2.protocol import binary as proto +from thriftpy2 import load +from thriftpy2.utils import serialize class TItem(TPayload): @@ -160,3 +162,63 @@ def test_write_huge_struct(): b = BytesIO() item = TItem(id=12345, phones=["1234567890"] * 100000) proto.TBinaryProtocol(b).write_struct(item) + + +def test_string_binary_equivalency(): + from thriftpy2.protocol.binary import TBinaryProtocolFactory + from thriftpy2.protocol.cybin import TCyBinaryProtocolFactory + string_binary_equivalency(TBinaryProtocolFactory) + string_binary_equivalency(TCyBinaryProtocolFactory) + + +def string_binary_equivalency(proto_factory): + container = load("./container.thrift") + l_item = container.ListItem() + l_item.list_string = ['foo', 'bar'] + l_item.list_list_string = [['foo', 'bar']] + + bl_item = container.BinListItem() + bl_item.list_binary = ['foo', 'bar'] + bl_item.list_list_binary = [['foo', 'bar']] + + assert serialize(l_item, proto_factory=proto_factory()) == serialize( + l_item, proto_factory=proto_factory()) + + m_item = container.MapItem() + m_item.map_string = {'foo': 'bar'} + m_item.map_map_string = {'foo': {'hello': 'world'}} + + bm_item = container.BinMapItem() + bm_item.map_binary = {'foo': 'bar'} + bm_item.map_map_binary = {'foo': {'hello': 'world'}} + + assert serialize(m_item, proto_factory=proto_factory()) == serialize( + bm_item, proto_factory=proto_factory()) + + x_item = container.MixItem() + x_item.list_map = [{'foo': 'bar'}] + x_item.map_list = {'foo': ['hello', 'world']} + + bx_item = container.BinMixItem() + bx_item.list_map = [{'foo': 'bar'}] + bx_item.map_list = {'foo': ['hello', 'world']} + + assert serialize(x_item, proto_factory=proto_factory()) == serialize( + bx_item, proto_factory=proto_factory()) + + l_item = container.ListItem() + l_item.list_string = ['foo', 'bar'] * 100 + l_item.list_list_string = [['foo', 'bar']] * 100 + + l_struct = container.ListStruct() + l_struct.list_items = [l_item] * 100 + + bl_item = container.BinListItem() + bl_item.list_binary = ['foo', 'bar'] * 100 + bl_item.list_list_binary = [['foo', 'bar']] * 100 + + bl_struct = container.BinListStruct() + bl_struct.list_items = [l_item] * 100 + + assert serialize(l_struct, proto_factory=proto_factory()) == serialize( + bl_struct, proto_factory=proto_factory()) diff --git a/tests/test_tornado.py b/tests/test_tornado.py index 7abbd86..1ad603f 100644 --- a/tests/test_tornado.py +++ b/tests/test_tornado.py @@ -1,173 +1,174 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import - import sys -from os import path -import logging -import socket - -import pytest -from tornado import gen, testing - -import thriftpy2 -from thriftpy2.tornado import make_client -from thriftpy2.tornado import make_server -from thriftpy2.transport import TTransportException - - -logging.basicConfig(level=logging.INFO) - -addressbook = thriftpy2.load(path.join(path.dirname(__file__), - "addressbook.thrift")) - - -class Dispatcher(object): - def __init__(self, io_loop): - self.io_loop = io_loop - self.registry = {} - - def add(self, person): - """ - bool add(1: Person person); - """ - if person.name in self.registry: - return False - self.registry[person.name] = person - return True - - def get(self, name): - """ - Person get(1: string name) throws (1: PersonNotExistsError not_exists); - """ - if not name: - # undeclared exception - raise ValueError('name cannot be empty') - if name not in self.registry: - raise addressbook.PersonNotExistsError( - 'Person "{}" does not exist!'.format(name)) - return self.registry[name] - - @gen.coroutine - def remove(self, name): - """ - bool remove(1: string name) throws (1: PersonNotExistsError not_exists) - """ - # delay action for later - yield gen.Task(self.io_loop.add_callback) - if not name: - # undeclared exception - raise ValueError('name cannot be empty') - if name not in self.registry: - raise addressbook.PersonNotExistsError( - 'Person "{}" does not exist!'.format(name)) - del self.registry[name] - raise gen.Return(True) - - -class TornadoRPCTestCase(testing.AsyncTestCase): - def mk_server(self): - server = make_server(addressbook.AddressBookService, - Dispatcher(self.io_loop), - io_loop=self.io_loop) - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('localhost', 0)) - sock.setblocking(0) - sock.listen(128) - - server.add_socket(sock) - self.port = sock.getsockname()[-1] - return server - - def mk_client(self): - return make_client(addressbook.AddressBookService, - '127.0.0.1', self.port, io_loop=self.io_loop) - - def mk_client_with_url(self): - return make_client(addressbook.AddressBookService, - io_loop=self.io_loop, - url='thrift://127.0.0.1:{port}'.format( - port=self.port)) - - def setUp(self): - super(TornadoRPCTestCase, self).setUp() - self.server = self.mk_server() - self.client = self.io_loop.run_sync(self.mk_client) - self.client_with_url = self.io_loop.run_sync(self.mk_client_with_url) - - def tearDown(self): - self.server.stop() - self.client.close() - self.client_with_url.close() - super(TornadoRPCTestCase, self).tearDown() - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_make_client(self): - linus = addressbook.Person(name='Linus Torvalds') - success = yield self.client_with_url.add(linus) - assert success - success = yield self.client.add(linus) - assert not success - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_synchronous_result(self): - dennis = addressbook.Person(name='Dennis Ritchie') - success = yield self.client.add(dennis) - assert success - success = yield self.client.add(dennis) - assert not success - person = yield self.client.get(dennis.name) - assert person.name == dennis.name - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_synchronous_exception(self): - exc = None - try: - yield self.client.get('Brian Kernighan') - except Exception as e: - exc = e - - assert isinstance(exc, addressbook.PersonNotExistsError) - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_synchronous_undeclared_exception(self): - exc = None - try: - yield self.client.get('') - except Exception as e: - exc = e - - assert isinstance(exc, TTransportException) - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_asynchronous_result(self): - dennis = addressbook.Person(name='Dennis Ritchie') - yield self.client.add(dennis) - success = yield self.client.remove(dennis.name) - assert success - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_asynchronous_exception(self): - exc = None - try: - yield self.client.remove('Brian Kernighan') - except Exception as e: - exc = e - assert isinstance(exc, addressbook.PersonNotExistsError) - - @testing.gen_test - @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") - def test_asynchronous_undeclared_exception(self): - exc = None - try: - yield self.client.remove('') - except Exception as e: - exc = e - assert isinstance(exc, TTransportException) + +if sys.version_info[0] == 3 and sys.version_info[1] >= 10: + pass +else: + from os import path + import logging + import socket + + import pytest + from tornado import gen, testing + + import thriftpy2 + from thriftpy2.tornado import make_client + from thriftpy2.tornado import make_server + from thriftpy2.transport import TTransportException + + logging.basicConfig(level=logging.INFO) + + addressbook = thriftpy2.load(path.join(path.dirname(__file__), + "addressbook.thrift")) + + class Dispatcher(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.registry = {} + + def add(self, person): + """ + bool add(1: Person person); + """ + if person.name in self.registry: + return False + self.registry[person.name] = person + return True + + def get(self, name): + """ + Person get(1: string name) throws (1: PersonNotExistsError not_exists); + """ + if not name: + # undeclared exception + raise ValueError('name cannot be empty') + if name not in self.registry: + raise addressbook.PersonNotExistsError( + 'Person "{}" does not exist!'.format(name)) + return self.registry[name] + + @gen.coroutine + def remove(self, name): + """ + bool remove(1: string name) throws (1: PersonNotExistsError not_exists) + """ + # delay action for later + yield gen.Task(self.io_loop.add_callback) + if not name: + # undeclared exception + raise ValueError('name cannot be empty') + if name not in self.registry: + raise addressbook.PersonNotExistsError( + 'Person "{}" does not exist!'.format(name)) + del self.registry[name] + raise gen.Return(True) + + class TornadoRPCTestCase(testing.AsyncTestCase): + def mk_server(self): + server = make_server(addressbook.AddressBookService, + Dispatcher(self.io_loop), + io_loop=self.io_loop) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('localhost', 0)) + sock.setblocking(0) + sock.listen(128) + + server.add_socket(sock) + self.port = sock.getsockname()[-1] + return server + + def mk_client(self): + return make_client(addressbook.AddressBookService, + '127.0.0.1', self.port, io_loop=self.io_loop) + + def mk_client_with_url(self): + return make_client(addressbook.AddressBookService, + io_loop=self.io_loop, + url='thrift://127.0.0.1:{port}'.format( + port=self.port)) + + def setUp(self): + super(TornadoRPCTestCase, self).setUp() + self.server = self.mk_server() + self.client = self.io_loop.run_sync(self.mk_client) + self.client_with_url = self.io_loop.run_sync( + self.mk_client_with_url) + + def tearDown(self): + self.server.stop() + self.client.close() + self.client_with_url.close() + super(TornadoRPCTestCase, self).tearDown() + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_make_client(self): + linus = addressbook.Person(name='Linus Torvalds') + success = yield self.client_with_url.add(linus) + assert success + success = yield self.client.add(linus) + assert not success + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_synchronous_result(self): + dennis = addressbook.Person(name='Dennis Ritchie') + success = yield self.client.add(dennis) + assert success + success = yield self.client.add(dennis) + assert not success + person = yield self.client.get(dennis.name) + assert person.name == dennis.name + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_synchronous_exception(self): + exc = None + try: + yield self.client.get('Brian Kernighan') + except Exception as e: + exc = e + + assert isinstance(exc, addressbook.PersonNotExistsError) + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_synchronous_undeclared_exception(self): + exc = None + try: + yield self.client.get('') + except Exception as e: + exc = e + + assert isinstance(exc, TTransportException) + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_asynchronous_result(self): + dennis = addressbook.Person(name='Dennis Ritchie') + yield self.client.add(dennis) + success = yield self.client.remove(dennis.name) + assert success + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_asynchronous_exception(self): + exc = None + try: + yield self.client.remove('Brian Kernighan') + except Exception as e: + exc = e + assert isinstance(exc, addressbook.PersonNotExistsError) + + @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") + def test_asynchronous_undeclared_exception(self): + exc = None + try: + yield self.client.remove('') + except Exception as e: + exc = e + assert isinstance(exc, TTransportException) diff --git a/thriftpy2/protocol/binary.py b/thriftpy2/protocol/binary.py index 9ca7a29..7b43160 100644 --- a/thriftpy2/protocol/binary.py +++ b/thriftpy2/protocol/binary.py @@ -83,10 +83,16 @@ def write_field_stop(outbuf): def write_list_begin(outbuf, etype, size): + if etype == TType.BINARY: + etype = TType.STRING outbuf.write(pack_i8(etype) + pack_i32(size)) def write_map_begin(outbuf, ktype, vtype, size): + if ktype == TType.BINARY: + ktype = TType.STRING + if vtype == TType.BINARY: + vtype = TType.STRING outbuf.write(pack_i8(ktype) + pack_i8(vtype) + pack_i32(size)) diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx index f789a93..14a3839 100644 --- a/thriftpy2/protocol/cybin/cybin.pyx +++ b/thriftpy2/protocol/cybin/cybin.pyx @@ -115,6 +115,9 @@ cdef inline write_list(CyTransportBase buf, object val, spec): e_type = spec[0] e_spec = spec[1] + if e_type == T_BINARY: + e_type = T_STRING + val_len = len(val) write_i08(buf, e_type) write_i32(buf, val_len) @@ -142,6 +145,9 @@ cdef inline write_dict(CyTransportBase buf, object val, spec): k_type = key[0] k_spec = key[1] + if k_type == T_BINARY: + k_type = T_STRING + value = spec[1] if isinstance(value, int): v_type = value @@ -150,6 +156,9 @@ cdef inline write_dict(CyTransportBase buf, object val, spec): v_type = value[0] v_spec = value[1] + if v_type == T_BINARY: + v_type = T_STRING + val_len = len(val) write_i08(buf, k_type)