From 6e96ef28c1ecb6e28a49a528bd8a9e8bfb440354 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 29 Oct 2015 14:22:58 -0700 Subject: [PATCH 1/4] Add support for including other Thrift files. Renamed .{services,types,constants} on generated modules to have `__` at start and end to avoid naming conflicts. Removed force flag from load(). The caching behavior is relied upon by the include logic. Users can instantiate new Loaders if they need the result to not be cached. --- .gitignore | 4 + CHANGELOG.rst | 5 + Makefile | 8 +- docs/overview.rst | 85 ++++++- tests/compile/test_compiler.py | 40 +--- tests/compile/test_includes.py | 398 +++++++++++++++++++++++++++++++++ tests/spec/test_service.py | 9 +- tests/test_loader.py | 5 - thriftrw/_buffer.pyx | 4 +- thriftrw/compile/__init__.py | 3 +- thriftrw/compile/compiler.py | 205 ++++++++++++++--- thriftrw/compile/link.py | 6 +- thriftrw/compile/scope.py | 87 ++++++- thriftrw/idl/__init__.py | 4 + thriftrw/idl/ast.py | 14 +- thriftrw/idl/parser.py | 2 +- thriftrw/loader.py | 32 +-- thriftrw/spec/service.pyx | 9 +- 18 files changed, 794 insertions(+), 126 deletions(-) create mode 100644 tests/compile/test_includes.py diff --git a/.gitignore b/.gitignore index c3d48cb..4c6d916 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ env/ _build *.swp node_modules +*.c +*.so +.benchmarks/ +.coverage diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5a08659..c9bdab4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Releases 0.6.0 (unreleased) ------------------ +- ``include`` statements are now supported. For more information, see + :ref:`including-modules`. - Added support for message envelopes. This makes it possible to talk with standard Apache Thrift services and clients. For more information, see :ref:`calling-apache-thrift`. @@ -19,6 +21,9 @@ Releases ``FunctionSpec``. - ``ServiceSpec`` now provides a ``lookup`` method to look up ``FunctionSpecs`` by name. +- Removed the ``force`` option on ``Loader.load``. +- In generated modules, renamed the ``types``, ``constants`` and ``services`` + attributes to ``__types__``, ``__constants__``, and ``__services__``. 0.5.2 (2015-10-19) diff --git a/Makefile b/Makefile index 48d0364..66c6dfd 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test lint docs docsopen clean +.PHONY: test lint docs docsopen clean bootstrap test_args := \ --cov thriftrw \ @@ -29,3 +29,9 @@ clean: find tests thriftrw -name \*.c -delete find tests thriftrw -name \*.so -delete make -C docs clean + +bootstrap: + pip install -r requirements.txt + pip install -r requirements-dev.txt + pip install -r requirements-test.txt + pip install -e . diff --git a/docs/overview.rst b/docs/overview.rst index 4ca6b4c..c5c40a9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -155,21 +155,50 @@ The generated module contains two top-level functions ``dumps`` and ``loads``. If the method name was not recognized or any other payload parsing errors. -.. py:attribute:: services +.. py:attribute:: __services__ Collection of all classes generated for all services defined in the source thrift file. -.. py:attribute:: types + .. versionchanged:: 0.6 + + Renamed from ``services`` to ``__services__``. + +.. py:attribute:: __types__ Collection of all classes generated for all types defined in the source thrift file. -.. py:attribute:: constants + .. versionchanged:: 0.6 + + Renamed from ``types`` to ``__types__``. + +.. py:attribute:: __includes__ + + Collection of modules which were referenced via ``include`` statements in + the generated module. + + .. versionadded:: 0.6 + +.. py:attribute:: __constants__ Mapping of constant name to value for all constants defined in the source thrift file. + .. versionchanged:: 0.6 + + Renamed from ``constants`` to ``__constants__``. + +Includes +~~~~~~~~ + +For an include:: + + include "shared.thrift" + +The generated module will include a top-level attribute ``shared`` which +references the generated module for ``shared.thrift``. + Structs ~~~~~~~ @@ -413,6 +442,56 @@ Thrift Type Primitive Type ``exception`` ``dict`` ============= ============== +.. _including-modules: + +Including other Thrift files +---------------------------- + +Types, services, and constants defined in different Thrift files may be +referenced by using ``include`` statements. Included modules will automatically +be compiled along with the module that included them, and they will be made +available in the generated module with the base name of the included file. + +For example, given:: + + // shared/types.thrift + + struct UUID { + 1: required i64 high + 2: required i64 low + } + +And:: + + // service.thrift + + include "shared/types.thrift" + + struct User { + 1: required types.UUID uuid + } + +You can do the following + +.. code-block:: python + + service = thriftrw.load('service.thrift') + + user_uuid = service.shared.UUID(...) + user = service.User(uuid=user_uuid) + + # ... + +Also note that you can ``load()`` Thrift files that have already been loaded +without extra cost because the result is cached by the system. + +.. code-block:: python + + service = thriftrw.load('service.thrift') + types = thriftrw.load('shared/types.thrift') + + assert service.types is types + .. _calling-apache-thrift: Calling Apache Thrift diff --git a/tests/compile/test_compiler.py b/tests/compile/test_compiler.py index 1f791bc..93c2622 100644 --- a/tests/compile/test_compiler.py +++ b/tests/compile/test_compiler.py @@ -21,41 +21,11 @@ from __future__ import absolute_import, unicode_literals, print_function import pytest -from functools import partial -from thriftrw.idl import Parser -from thriftrw.protocol import BinaryProtocol -from thriftrw.compile import Compiler from thriftrw.errors import ThriftCompilerError @pytest.fixture -def parse(): - return Parser().parse - - -@pytest.fixture -def compile(request): - return partial(Compiler(BinaryProtocol()).compile, request.node.name) - - -@pytest.fixture -def loads(parse, compile): - return (lambda s: compile(parse(s))) - - -def test_include_disallowed(loads): - with pytest.raises(ThriftCompilerError) as exc_info: - loads(''' - namespace py foo - namespace js bar - - include "foo.thrift" - ''') - - assert 'thriftrw does not support including' in str(exc_info) - - def test_unknown_type(loads): with pytest.raises(ThriftCompilerError) as exc_info: loads(''' @@ -118,14 +88,14 @@ def test_services_and_types(loads): 'z': [m.x, m.y], 'x': 42, 'y': 123, - } == m.constants + } == m.__constants__ assert ( - m.types == (m.Foo, m.Bar) or - m.types == (m.Bar, m.Foo) + m.__types__ == (m.Foo, m.Bar) or + m.__types__ == (m.Bar, m.Foo) ) assert ( - m.services == (m.A, m.B) or - m.services == (m.B, m.A) + m.__services__ == (m.A, m.B) or + m.__services__ == (m.B, m.A) ) diff --git a/tests/compile/test_includes.py b/tests/compile/test_includes.py new file mode 100644 index 0000000..ff01773 --- /dev/null +++ b/tests/compile/test_includes.py @@ -0,0 +1,398 @@ +# Copyright (c) 2015 Uber Technologies, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from __future__ import absolute_import, unicode_literals, print_function + +import time +import pytest + +from thriftrw.loader import Loader +from thriftrw.protocol import BinaryProtocol +from thriftrw.errors import ThriftCompilerError + + +@pytest.fixture +def loader(): + return Loader(BinaryProtocol()) + + +def test_simple_include(tmpdir, loader): + tmpdir.join('types.thrift').write(''' + struct Item { + 1: required string key + 2: required string value + } + ''') + + tmpdir.join('svc.thrift').write(''' + include "types.thrift" + + struct BatchGetResponse { + 1: required list items = [] + } + + service ItemStore { + BatchGetResponse batchGetItems(1: list keys) + } + ''') + + svc = loader.load(str(tmpdir.join('svc.thrift'))) + + # Loading the module we depend on explicitly should give back the same + # generated module. + assert svc.types is loader.load(str(tmpdir.join('types.thrift'))) + assert svc.__includes__ == (svc.types,) + + item = svc.types.Item(key='foo', value='bar') + response = svc.BatchGetResponse([item]) + + assert svc.types.dumps(item) == bytes(bytearray([ + # 1: 'foo' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x66, 0x6f, 0x6f, + + # 2: 'bar' + 0x0B, + 0x00, 0x02, + 0x00, 0x00, 0x00, 0x03, + 0x62, 0x61, 0x72, + + 0x00, + ])) + + svc.dumps(response) == bytes(bytearray([ + # 1: [item] + 0x0F, + 0x00, 0x01, + + # item + 0x0C, + 0x00, 0x00, 0x00, 0x01, + + # 1: 'foo' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x66, 0x6f, 0x6f, + + # 2: 'bar' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x62, 0x61, 0x72, + + 0x00, + ])) + + +def test_include_relative(tmpdir, loader): + tmpdir.join('types/shared.thrift').ensure().write(''' + typedef i64 Timestamp + + exception InternalError { + 1: required string message + } + ''') + + tmpdir.join('team/myservice/myservice.thrift').ensure().write(''' + include "../../types/shared.thrift" + + service Service { + shared.Timestamp getCurrentTime() + throws (1: shared.InternalError internalError) + } + ''') + + myservice = loader.load( + str(tmpdir.join('team/myservice/myservice.thrift')) + ) + + assert myservice.__includes__ == (myservice.shared,) + + myservice.Service.getCurrentTime.response( + success=int(time.time() * 1000) + ) + + myservice.Service.getCurrentTime.response( + internalError=myservice.shared.InternalError('great sadness') + ) + + with pytest.raises(TypeError) as exc_info: + myservice.Service.getCurrentTime.response( + success='2015-10-29T15:00:00Z' + ) + assert 'Cannot serialize' in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + myservice.Service.getCurrentTime.response( + internalError=ZeroDivisionError() + ) + + assert 'Cannot serialize' in str(exc_info) + + +def test_cyclic_includes(tmpdir, loader): + tmpdir.join('node.thrift').write(''' + include "value.thrift" + + struct Node { + 1: required string name + 2: required value.Value value + } + ''') + + tmpdir.join('value.thrift').write(''' + include "node.thrift" + + struct Value { + 1: required list nodes + } + ''') + + node = loader.load(str(tmpdir.join('node.thrift'))) + + assert node.__includes__ == (node.value,) + assert node.value.__includes__ == (node,) + + assert ( + node.dumps(node.Node('hello', node.value.Value([]))) == + bytes(bytearray([ + + # 1: 'hello' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x05, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, # 'hello' + + # 2: {1: []} + 0x0C, + 0x00, 0x02, + + # 1: [] + 0x0F, + 0x00, 0x01, + + # [] + 0x0C, + 0x00, 0x00, 0x00, 0x00, + + 0x00, + + 0x00, + ])) + ) + + +def test_inherit_included_service(tmpdir, loader): + tmpdir.join('common.thrift').write(''' + service BaseService { + string serviceName() + bool healthy() + } + ''') + + tmpdir.join('keyvalue.thrift').write(''' + include "common.thrift" + + service KeyValue extends common.BaseService { + binary get(1: binary key) + void put(1: binary key, 2: binary value) + } + ''') + + keyvalue = loader.load(str(tmpdir.join('keyvalue.thrift'))) + + assert keyvalue.__includes__ == (keyvalue.common,) + + assert issubclass(keyvalue.KeyValue, keyvalue.common.BaseService) + + assert ( + keyvalue.dumps(keyvalue.KeyValue.healthy.response(success=True)) == + bytes(bytearray([0x02, 0x00, 0x00, 0x01, 0x00])) + ) + + +def test_include_constants(tmpdir, loader): + tmpdir.join('bar.thrift').write('const i32 some_num = 42') + + tmpdir.join('foo.thrift').write(''' + include "bar.thrift" + + const list nums = [1, bar.some_num, 2]; + ''') + + foo = loader.load(str(tmpdir.join('foo.thrift'))) + assert foo.nums == [1, 42, 2] == [1, foo.bar.some_num, 2] + + +def test_multi_level_cyclic_import(tmpdir, loader): + + # |- a.thrift + # |- one/ + # |- b.thrift + # |- c.thrift + # |- two/ + # |- d.thrift + + tmpdir.join('a.thrift').write(''' + include "one/b.thrift" + include "one/c.thrift" + ''') + + tmpdir.join('one/b.thrift').ensure().write(''' + include "two/d.thrift" + ''') + + tmpdir.join('one/c.thrift').ensure().write(''' + include "two/d.thrift" + ''') + + tmpdir.join('one/two/d.thrift').ensure().write(''' + include "../../a.thrift" + ''') + + a = loader.load(str(tmpdir.join('a.thrift'))) + assert ( + a.__includes__ == (a.b, a.c) or + a.__includes__ == (a.c, a.b) + ) + + assert a.b.d is a.c.d + + assert a.b.__includes__ == (a.b.d,) + assert a.c.__includes__ == (a.c.d,) + + assert a.b.d.a is a + assert a.c.d.a is a + assert a.b.d.__includes__ == (a,) + assert a.c.d.__includes__ == (a,) + + +@pytest.mark.parametrize('root, data, msgs', [ + ( + # File does not exist + 'foo.thrift', + [('foo.thrift', 'include "bar.thrift"')], + [ + 'Cannot include "bar.thrift"', + 'The file', 'does not exist' + ] + ), + ( + # Two modules in subdirectories with the same name. This should be + # resolved once we implement the "import as" syntax. + 'index.thrift', + [ + ('foo/shared.thrift', 'typedef string timestamp'), + ('bar/shared.thrift', 'typedef string UUID'), + ('index.thrift', ''' + include "foo/shared.thrift" + include "bar/shared.thrift" + '''), + ], + [ + 'Cannot include module "shared"', + 'The name is already taken' + ] + ), + ( + # Unknown type reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "bar.thrift" + + struct Foo { 1: required bar.Bar b } + '''), + ('bar.thrift', ''), + ], + ['Unknown type "Bar" referenced'] + ), + ( + # Unknown service reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "bar.thrift" + + service Foo extends bar.Bar { + } + '''), + ('bar.thrift', 'service NotBar {}'), + ], + ['Unknown service "Bar" referenced'] + ), + ( + # Unknown constant reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "bar.thrift" + + const i32 x = bar.y; + '''), + ('bar.thrift', 'const i32 z = 42'), + ], + ['Unknown constant "y" referenced'] + ), + ( + # Bad type reference + 'foo.thrift', + [('foo.thrift', 'struct Foo { 1: required bar.Bar b }')], + ['Unknown type "bar.Bar" referenced'] + ), + ( + # Bad service reference + 'foo.thrift', + [('foo.thrift', 'service Foo extends bar.Bar {}')], + ['Unknown service "bar.Bar" referenced'] + ), + ( + # Bad constant reference + 'foo.thrift', + [('foo.thrift', 'const i32 x = bar.y')], + ['Unknown constant "bar.y" referenced'] + ), +]) +def test_bad_includes(tmpdir, loader, root, data, msgs): + for path, contents in data: + tmpdir.join(path).ensure().write(contents) + + with pytest.raises(ThriftCompilerError) as exc_info: + loader.load(str(tmpdir.join(root))) + + for msg in msgs: + assert msg in str(exc_info) + + +def test_include_disallowed_with_loads(loads): + with pytest.raises(ThriftCompilerError) as exc_info: + loads(''' + namespace py foo + namespace js bar + + include "foo.thrift" + ''') + + assert ( + 'Includes are not supported when using the "loads()"' in str(exc_info) + ) diff --git a/tests/spec/test_service.py b/tests/spec/test_service.py index fac52c7..596f9a9 100644 --- a/tests/spec/test_service.py +++ b/tests/spec/test_service.py @@ -29,6 +29,7 @@ from thriftrw.spec.service import ServiceSpec from thriftrw.spec.struct import FieldSpec from thriftrw.idl import Parser +from thriftrw.idl.ast import ServiceReference from thriftrw.wire import ttype from ..util.value import vstruct, vbinary, vmap, vbool @@ -117,7 +118,7 @@ def test_compile(parse): ''')) assert spec.name == 'KeyValue' - assert spec.parent == 'BaseService' + assert spec.parent == ServiceReference('BaseService', 2) put_item_spec = spec.functions[0] get_item_spec = spec.functions[1] @@ -161,7 +162,7 @@ def test_link_unknown_parent(loads): with pytest.raises(ThriftCompilerError) as exc_info: loads('service A extends B {}') - assert 'Service "A" inherits from unknown service "B"' in str(exc_info) + assert 'Unknown service "B" referenced at line' in str(exc_info) def test_load(loads): @@ -200,8 +201,8 @@ def test_load(loads): } ''') assert ( - (keyvalue.KeyValue, keyvalue.BaseService) == keyvalue.services or - (keyvalue.BaseService, keyvalue.KeyValue) == keyvalue.services + (keyvalue.KeyValue, keyvalue.BaseService) == keyvalue.__services__ or + (keyvalue.BaseService, keyvalue.KeyValue) == keyvalue.__services__ ) KeyValue = keyvalue.KeyValue diff --git a/tests/test_loader.py b/tests/test_loader.py index 173cdd1..6e00d54 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -82,14 +82,9 @@ def test_caching(tmpdir, monkeypatch): loader = Loader() mod1 = loader.load(path) - assert path in loader.compiled_modules - mod2 = loader.load(path) assert mod1 is mod2 - mod3 = loader.load(path, force=True) - assert mod3 is not mod2 - @pytest.mark.unimport('foo.bar.svc') def test_install_absolute(tmpdir, monkeypatch): diff --git a/thriftrw/_buffer.pyx b/thriftrw/_buffer.pyx index 7479298..8b47c3d 100644 --- a/thriftrw/_buffer.pyx +++ b/thriftrw/_buffer.pyx @@ -152,9 +152,9 @@ cdef class ReadBuffer(object): Number of bytes to read. :returns: ``bytes`` object containing exactly ``count`` bytes. - :raises Exception: + :raises EndOfInputError: If the number of bytes available in the buffer is less than - ``count``. TODO more specific exception + ``count``. """ if count > self.length - self.offset: raise EndOfInputError( diff --git a/thriftrw/compile/__init__.py b/thriftrw/compile/__init__.py index aafeb1c..9b7a3e9 100644 --- a/thriftrw/compile/__init__.py +++ b/thriftrw/compile/__init__.py @@ -24,6 +24,7 @@ from __future__ import absolute_import, unicode_literals, print_function from .compiler import Compiler +from .scope import Scope -__all__ = ['Compiler'] +__all__ = ['Compiler', 'Scope'] diff --git a/thriftrw/compile/compiler.py b/thriftrw/compile/compiler.py index 68de90d..51cf147 100644 --- a/thriftrw/compile/compiler.py +++ b/thriftrw/compile/compiler.py @@ -20,25 +20,113 @@ from __future__ import absolute_import, unicode_literals, print_function +import os.path + +from .scope import Scope from .generate import Generator from .link import TypeSpecLinker from .link import ConstSpecLinker from .link import ServiceSpecLinker -from .scope import Scope from ..errors import ThriftCompilerError +from thriftrw.idl import Parser from thriftrw._runtime import Serializer, Deserializer __all__ = ['Compiler'] +LINKERS = [ConstSpecLinker, TypeSpecLinker, ServiceSpecLinker] + + +class ModuleSpec(object): + """Specification for a single module.""" + + __slots__ = ( + 'name', 'path', 'scope', 'surface', 'includes', 'protocol', 'linked' + ) + + def __init__(self, name, protocol, path=None): + """ + :param name: + Name of the module. + :param path: + Path to the Thrift file from which this module was compiled. This + may be omitted if the module was compiled from an inline string + (using the ``loads()`` API, for example). + """ + + self.name = name + self.path = path + self.protocol = protocol + self.linked = False + self.scope = Scope(name, path) + self.surface = self.scope.module + + # Mapping of names of inculded modules to their corresponding specs. + self.includes = {} + + # TODO Scope can probably be eventually folded into this class. + + @property + def can_include(self): + """Whether this module is allowed to include other modules. + + This is allowed only if the module was compiled from a file since + include paths are relative to the file in which they are mentioned. + """ + return self.path is not None + + def add_include(self, module_spec): + """Adds a module as an included module. + + :param module_spec: + ModuleSpec of the included module. + """ + assert self.can_include + + if module_spec.name in self.includes: + raise ThriftCompilerError( + 'Cannot include module "%s" in "%s". ' + 'The name is already taken.' + % (module_spec.name, self.path) + ) + + self.includes[module_spec.name] = module_spec + self.scope.add_include( + module_spec.name, + module_spec.scope, + module_spec.surface, + ) + + def link(self): + """Link all the types in this module and all included modules.""" + if self.linked: + return self + + self.linked = True + + included_modules = [] + + # Link includes + for include in self.includes.values(): + included_modules.append(include.link().surface) + + self.scope.add_surface('__includes__', tuple(included_modules)) + + # Link self + for linker in LINKERS: + linker(self.scope).link() + + self.scope.add_surface('loads', Deserializer(self.protocol)) + self.scope.add_surface('dumps', Serializer(self.protocol)) + + return self + class Compiler(object): """Compiles IDLs into Python modules.""" - LINKERS = [ConstSpecLinker, TypeSpecLinker, ServiceSpecLinker] - - __slots__ = ('protocol', 'strict') + __slots__ = ('protocol', 'strict', 'parser', '_module_specs') def __init__(self, protocol, strict=True): """Initialize the compiler. @@ -49,26 +137,49 @@ def __init__(self, protocol, strict=True): self.protocol = protocol self.strict = strict - def compile(self, name, program): - """Compile the given parsed Thrift document into a Python module. + self.parser = Parser() + + # Mapping from absolute file path to ModuleSpec for all modules. + self._module_specs = {} + + def compile(self, name, contents, path=None): + """Compile the given Thrift document into a Python module. The generated module contains, - .. py:attribute:: services + .. py:attribute:: __services__ A collection of generated classes for all services defined in the thrift file. - .. py:attribute:: types + .. versionchanged:: 0.6 + + Renamed from ``services`` to ``__services__``. + + .. py:attribute:: __types__ A collection of generated types for all types defined in the thrift file. - .. py:attribute:: constants + .. versionchanged:: 0.6 + + Renamed from ``types`` to ``__types__``. + + .. py:attribute:: __includes__ + + A collection of modules included by this module. + + .. versionadded:: 0.6 + + .. py:attribute:: __constants__ A mapping of constant name to value for all constants defined in the thrift file. + .. versionchanged:: 0.6 + + Renamed from ``constants`` to ``__constants__``. + .. py:function:: dumps(obj) Serializes the given object using the protocol the compiler was @@ -107,41 +218,79 @@ def compile(self, name, program): :py:class:`thriftrw.spec.ServiceFunction` objects for each method defined in the service. - .. versionadded:: 0.2 - The ``constants`` attribute in generated modules. - :param str name: Name of the Thrift document. This will be the name of the generated module. - :param thriftrw.idl.Program program: - AST of the parsted Thrift document. + :param str contents: + Thrift document to compile + :param str path: + Path to the Thrift file being compiled. If not specified, imports + from within the Thrift file will be disallowed. :returns: - The generated module. + ModuleSpec of the generated module. """ - scope = Scope(name) + assert name + + if path: + path = os.path.abspath(path) + if path in self._module_specs: + return self._module_specs[path] + + module_spec = ModuleSpec(name, self.protocol, path) + if path: + self._module_specs[path] = module_spec + + program = self.parser.parse(contents) + header_processor = HeaderProcessor(self, module_spec) for header in program.headers: - header.apply(self) + header.apply(header_processor) - generator = Generator(scope, strict=self.strict) + generator = Generator(module_spec.scope, strict=self.strict) for definition in program.definitions: generator.process(definition) - # TODO Linker can probably just be a callable. - for linker in self.LINKERS: - linker(scope).link() + return module_spec - scope.add_surface('loads', Deserializer(self.protocol)) - scope.add_surface('dumps', Serializer(self.protocol)) - return scope.module +class HeaderProcessor(object): + """Processes headers found in the Thrift file.""" + + __slots__ = ('compiler', 'module_spec') + + def __init__(self, compiler, module_spec): + self.compiler = compiler + self.module_spec = module_spec def visit_include(self, include): - raise ThriftCompilerError( - 'Include of "%s" found on line %d. ' - 'thriftrw does not support including other Thrift files.' - % (include.path, include.lineno) + + if not self.module_spec.can_include: + raise ThriftCompilerError( + 'Include of "%s" found on line %d. ' + 'Includes are not supported when using the "loads()" API.' + 'Try loading the file using the "load()" API.' + % (include.path, include.lineno) + ) + + # Includes are relative to directory of the Thrift file being + # compiled. + path = os.path.join( + os.path.dirname(self.module_spec.path), include.path ) + if not os.path.isfile(path): + raise ThriftCompilerError( + 'Cannot include "%s" on line %d in %s. ' + 'The file "%s" does not exist.' + % (include.path, include.lineno, self.module_spec.path, path) + ) + + name = os.path.splitext(os.path.basename(include.path))[0] + with open(path, 'r') as f: + contents = f.read() + + included_module_spec = self.compiler.compile(name, contents, path) + self.module_spec.add_include(included_module_spec) + def visit_namespace(self, namespace): pass # nothing to do diff --git a/thriftrw/compile/link.py b/thriftrw/compile/link.py index 0e1aa4b..bdfaf12 100644 --- a/thriftrw/compile/link.py +++ b/thriftrw/compile/link.py @@ -47,7 +47,7 @@ def link(self): types.append(type_spec.surface) self.scope.type_specs = type_specs - self.scope.add_surface('types', tuple(types)) + self.scope.add_surface('__types__', tuple(types)) class ServiceSpecLinker(object): @@ -73,7 +73,7 @@ def link(self): services.append(service_spec.surface) self.scope.service_specs = service_specs - self.scope.add_surface('services', tuple(services)) + self.scope.add_surface('__services__', tuple(services)) class ConstSpecLinker(object): @@ -99,4 +99,4 @@ def link(self): constants[const_spec.name] = const_spec.surface self.scope.const_specs = const_specs - self.scope.add_surface('constants', constants) + self.scope.add_surface('__constants__', constants) diff --git a/thriftrw/compile/scope.py b/thriftrw/compile/scope.py index cea4c90..7d0b164 100644 --- a/thriftrw/compile/scope.py +++ b/thriftrw/compile/scope.py @@ -35,17 +35,22 @@ class Scope(object): reference to the final generated module. """ - __slots__ = ('const_specs', 'type_specs', 'module', 'service_specs') + __slots__ = ( + 'const_specs', 'type_specs', 'module', 'service_specs', + 'included_scopes', 'path' + ) - def __init__(self, name): + def __init__(self, name, path=None): """Initialize the scope. :param name: Name of the generated module. """ + self.path = path self.type_specs = {} self.const_specs = {} self.service_specs = {} + self.included_scopes = {} self.module = types.ModuleType(str(name)) @@ -59,21 +64,83 @@ def __str__(self): __repr__ = __str__ + def __in_path(self): + """Helper for error messages to say "in $path" if the scope has a + non-none path. + """ + if self.path: + return ' in "%s"' % self.path + else: + return '' + def resolve_const_spec(self, name, lineno): """Finds and links the ConstSpec with the given name.""" - if name not in self.const_specs: - raise ThriftCompilerError( - 'Unknown constant "%s" referenced at line %d' % (name, lineno) + + if name in self.const_specs: + return self.const_specs[name].link(self) + + if '.' in name: + include_name, component = name.split('.', 1) + if include_name in self.included_scopes: + return self.included_scopes[include_name].resolve_const_spec( + component, lineno + ) + + raise ThriftCompilerError( + 'Unknown constant "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() ) - return self.const_specs[name].link(self) + ) def resolve_type_spec(self, name, lineno): """Finds and links the TypeSpec with the given name.""" - if name not in self.type_specs: - raise ThriftCompilerError( - 'Unknown type "%s" referenced at line %d' % (name, lineno) + + if name in self.type_specs: + return self.type_specs[name].link(self) + + if '.' in name: + include_name, component = name.split('.', 1) + if include_name in self.included_scopes: + return self.included_scopes[include_name].resolve_type_spec( + component, lineno + ) + + raise ThriftCompilerError( + 'Unknown type "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() ) - return self.type_specs[name].link(self) + ) + + def resolve_service_spec(self, name, lineno): + """Finds and links the ServiceSpec with the given name.""" + + if name in self.service_specs: + return self.service_specs[name].link(self) + + if '.' in name: + include_name, component = name.split('.', 2) + if include_name in self.included_scopes: + return self.included_scopes[ + include_name + ].resolve_service_spec(component, lineno) + + raise ThriftCompilerError( + 'Unknown service "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() + ) + ) + + def add_include(self, name, included_scope, module): + """Register an imported module into this scope. + + Raises ``ThriftCompilerError`` if the name has already been used. + """ + # The compiler already ensures this. If we still get here with a + # conflict, that's a bug. + assert name not in self.included_scopes + + self.included_scopes[name] = included_scope + self.add_surface(name, module) def add_service_spec(self, service_spec): """Registers the given ``ServiceSpec`` into the scope. diff --git a/thriftrw/idl/__init__.py b/thriftrw/idl/__init__.py index 5d87371..70fe7a9 100644 --- a/thriftrw/idl/__init__.py +++ b/thriftrw/idl/__init__.py @@ -60,6 +60,8 @@ .. autoclass:: Service +.. autoclass:: ServiceReference + .. autoclass:: Function .. autoclass:: Field @@ -104,6 +106,7 @@ Union, Exc, Service, + ServiceReference, Function, Field, PrimitiveType, @@ -134,6 +137,7 @@ 'Union', 'Exc', 'Service', + 'ServiceReference', 'Function', 'Field', 'PrimitiveType', diff --git a/thriftrw/idl/ast.py b/thriftrw/idl/ast.py index 81b166c..36baed3 100644 --- a/thriftrw/idl/ast.py +++ b/thriftrw/idl/ast.py @@ -34,6 +34,7 @@ 'Union', 'Exc', 'Service', + 'ServiceReference', 'Function', 'Field', 'PrimitiveType', @@ -277,6 +278,15 @@ def apply(self, visitor): # defined by Python. +class ServiceReference(namedtuple('ServiceReference', 'name lineno')): + """A reference to another service. + + .. py:attribute:: name + + Name of the referenced service. + """ + + class Service( namedtuple('Service', 'name functions parent annotations lineno') ): @@ -293,8 +303,8 @@ class Service( .. py:attribute:: parent - Name of the service that this service extends. ``None`` if this service - doesn't have a parent service. + :py:class:`ServiceReference` to the parent service or ``None`` if this + service dosen't have a parent service. .. py:attribute:: annotations diff --git a/thriftrw/idl/parser.py b/thriftrw/idl/parser.py index 4ca347d..7965d58 100644 --- a/thriftrw/idl/parser.py +++ b/thriftrw/idl/parser.py @@ -235,7 +235,7 @@ def p_service(self, p): p[0] = ast.Service( name=p[2], functions=p[6], - parent=p[4], + parent=ast.ServiceReference(p[4], p.lineno(4)), annotations=p[8], lineno=p.lineno(2), ) diff --git a/thriftrw/loader.py b/thriftrw/loader.py index b714b2f..2396350 100644 --- a/thriftrw/loader.py +++ b/thriftrw/loader.py @@ -23,7 +23,6 @@ import inspect import os.path -from .idl import Parser from .compile import Compiler from .protocol import BinaryProtocol @@ -31,7 +30,7 @@ class Loader(object): """Loads and compiles Thrift files.""" - __slots__ = ('parser', 'compiler', 'compiled_modules') + __slots__ = ('compiler',) def __init__(self, protocol=None, strict=True): """Initialize a loader. @@ -49,14 +48,8 @@ def __init__(self, protocol=None, strict=True): compatibility with existing Thrift files. """ protocol = protocol or BinaryProtocol() - - self.parser = Parser() self.compiler = Compiler(protocol, strict=strict) - # Mapping of absolute file path to compiled module. This is used to - # cache the result of calling load() multiple times on the same file. - self.compiled_modules = {} - def loads(self, name, document): """Parse and compile the given Thrift document. @@ -65,38 +58,27 @@ def loads(self, name, document): :param str document: The Thrift IDL as a string. """ - program = self.parser.parse(document) - return self.compiler.compile(name, program) + return self.compiler.compile(name, document).link().surface - def load(self, path, name=None, force=False): + def load(self, path, name=None): """Load and compile the given Thrift file. - If the file was already compiled before, a cached copy of the compiled - module is returned. - :param str path: Path to the ``.thrift`` file. :param str name: Name of the generated module. Defaults to the base name of the file. - :param bool force: - Whether to ignore the cache and load the file anew. Defaults to - False. :returns: The compiled module. """ - path = os.path.abspath(path) - if path in self.compiled_modules and not force: - return self.compiled_modules[path] - if name is None: - # TODO do we care if the file extension is .thrift? name = os.path.splitext(os.path.basename(path))[0] + # TODO do we care if the file extension is .thrift? + with open(path, 'r') as f: document = f.read() - module = self.loads(name, document) - self.compiled_modules[path] = module - return module + + return self.compiler.compile(name, document, path).link().surface _DEFAULT_LOADER = Loader(protocol=BinaryProtocol()) diff --git a/thriftrw/spec/service.pyx b/thriftrw/spec/service.pyx index 5a2a4e8..d959a23 100644 --- a/thriftrw/spec/service.pyx +++ b/thriftrw/spec/service.pyx @@ -357,12 +357,9 @@ class ServiceSpec(object): self.linked = True if self.parent is not None: - if self.parent not in scope.service_specs: - raise ThriftCompilerError( - 'Service "%s" inherits from unknown service "%s"' - % (self.name, self.parent) - ) - self.parent = scope.service_specs[self.parent].link(scope) + self.parent = scope.resolve_service_spec( + self.parent.name, self.parent.lineno + ) self.functions = [func.link(scope) for func in self.functions] self.surface = service_cls(self, scope) From c31bc3d55f76ddeafecc520f4ff76b2b052c37ff Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 30 Oct 2015 16:19:40 -0700 Subject: [PATCH 2/4] Add note about include paths being relative to the current file --- docs/overview.rst | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index c5c40a9..309ed0e 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -448,9 +448,10 @@ Including other Thrift files ---------------------------- Types, services, and constants defined in different Thrift files may be -referenced by using ``include`` statements. Included modules will automatically -be compiled along with the module that included them, and they will be made -available in the generated module with the base name of the included file. +referenced by using ``include`` statements with paths **relative to the current +.thrift file**. Included modules will automatically be compiled along with the +module that included them, and they will be made available in the generated +module with the base name of the included file. For example, given:: @@ -463,9 +464,9 @@ For example, given:: And:: - // service.thrift + // services/user.thrift - include "shared/types.thrift" + include "../shared/types.thrift" struct User { 1: required types.UUID uuid @@ -475,7 +476,7 @@ You can do the following .. code-block:: python - service = thriftrw.load('service.thrift') + service = thriftrw.load('services/user.thrift') user_uuid = service.shared.UUID(...) user = service.User(uuid=user_uuid) @@ -487,7 +488,7 @@ without extra cost because the result is cached by the system. .. code-block:: python - service = thriftrw.load('service.thrift') + service = thriftrw.load('services/user.thrift') types = thriftrw.load('shared/types.thrift') assert service.types is types From 42e1f14c96893ce0639b0bcf125e57168c1537bf Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Mon, 2 Nov 2015 13:05:52 -0800 Subject: [PATCH 3/4] Restrict includes to start with ./ or ../ This clarifies that the includes are relative and ensures that the version without the "./" or "../" is reserved for later if we decide to add a thrift path or something similar. --- docs/overview.rst | 11 ++++-- tests/compile/test_includes.py | 66 +++++++++++++++++++++++++--------- thriftrw/compile/compiler.py | 7 ++++ thriftrw/idl/ast.py | 2 +- 4 files changed, 66 insertions(+), 20 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 309ed0e..f4dd996 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -194,11 +194,15 @@ Includes For an include:: - include "shared.thrift" + include "./shared.thrift" The generated module will include a top-level attribute ``shared`` which references the generated module for ``shared.thrift``. +Note that paths in include statements are relative to the directory containing +the Thrift file and they must be in the from ``./foo.thrift``, +``./foo/bar.thrift``, ``../baz.thrift``, and so on. + Structs ~~~~~~~ @@ -449,7 +453,10 @@ Including other Thrift files Types, services, and constants defined in different Thrift files may be referenced by using ``include`` statements with paths **relative to the current -.thrift file**. Included modules will automatically be compiled along with the +.thrift file**. The paths must be in the form ``./foo.thrift``, +``./foo/bar.thrift``, ``../baz.thrift``, and so on. + +Included modules will automatically be compiled along with the module that included them, and they will be made available in the generated module with the base name of the included file. diff --git a/tests/compile/test_includes.py b/tests/compile/test_includes.py index ff01773..b7bc841 100644 --- a/tests/compile/test_includes.py +++ b/tests/compile/test_includes.py @@ -42,7 +42,7 @@ def test_simple_include(tmpdir, loader): ''') tmpdir.join('svc.thrift').write(''' - include "types.thrift" + include "./types.thrift" struct BatchGetResponse { 1: required list items = [] @@ -152,7 +152,7 @@ def test_include_relative(tmpdir, loader): def test_cyclic_includes(tmpdir, loader): tmpdir.join('node.thrift').write(''' - include "value.thrift" + include "./value.thrift" struct Node { 1: required string name @@ -161,7 +161,7 @@ def test_cyclic_includes(tmpdir, loader): ''') tmpdir.join('value.thrift').write(''' - include "node.thrift" + include "./node.thrift" struct Value { 1: required list nodes @@ -211,7 +211,7 @@ def test_inherit_included_service(tmpdir, loader): ''') tmpdir.join('keyvalue.thrift').write(''' - include "common.thrift" + include "./common.thrift" service KeyValue extends common.BaseService { binary get(1: binary key) @@ -235,7 +235,7 @@ def test_include_constants(tmpdir, loader): tmpdir.join('bar.thrift').write('const i32 some_num = 42') tmpdir.join('foo.thrift').write(''' - include "bar.thrift" + include "./bar.thrift" const list nums = [1, bar.some_num, 2]; ''') @@ -244,6 +244,26 @@ def test_include_constants(tmpdir, loader): assert foo.nums == [1, 42, 2] == [1, foo.bar.some_num, 2] +def test_include_enums(tmpdir, loader): + tmpdir.join('foo.thrift').write(''' + enum Role { + DISABLED = 0, + USER = 1, + MOD = 2, + ADMIN = 3, + } + ''') + + tmpdir.join('bar.thrift').write(''' + include "./foo.thrift" + + const foo.Role DEFAULT_ROLE = foo.Role.USER + ''') + + bar = loader.load(str(tmpdir.join('bar.thrift'))) + assert bar.DEFAULT_ROLE == bar.foo.Role.USER == 1 + + def test_multi_level_cyclic_import(tmpdir, loader): # |- a.thrift @@ -254,16 +274,16 @@ def test_multi_level_cyclic_import(tmpdir, loader): # |- d.thrift tmpdir.join('a.thrift').write(''' - include "one/b.thrift" - include "one/c.thrift" + include "./one/b.thrift" + include "./one/c.thrift" ''') tmpdir.join('one/b.thrift').ensure().write(''' - include "two/d.thrift" + include "./two/d.thrift" ''') tmpdir.join('one/c.thrift').ensure().write(''' - include "two/d.thrift" + include "./two/d.thrift" ''') tmpdir.join('one/two/d.thrift').ensure().write(''' @@ -291,9 +311,9 @@ def test_multi_level_cyclic_import(tmpdir, loader): ( # File does not exist 'foo.thrift', - [('foo.thrift', 'include "bar.thrift"')], + [('foo.thrift', 'include "./bar.thrift"')], [ - 'Cannot include "bar.thrift"', + 'Cannot include "./bar.thrift"', 'The file', 'does not exist' ] ), @@ -305,8 +325,8 @@ def test_multi_level_cyclic_import(tmpdir, loader): ('foo/shared.thrift', 'typedef string timestamp'), ('bar/shared.thrift', 'typedef string UUID'), ('index.thrift', ''' - include "foo/shared.thrift" - include "bar/shared.thrift" + include "./foo/shared.thrift" + include "./bar/shared.thrift" '''), ], [ @@ -319,7 +339,7 @@ def test_multi_level_cyclic_import(tmpdir, loader): 'foo.thrift', [ ('foo.thrift', ''' - include "bar.thrift" + include "./bar.thrift" struct Foo { 1: required bar.Bar b } '''), @@ -332,7 +352,7 @@ def test_multi_level_cyclic_import(tmpdir, loader): 'foo.thrift', [ ('foo.thrift', ''' - include "bar.thrift" + include "./bar.thrift" service Foo extends bar.Bar { } @@ -346,7 +366,7 @@ def test_multi_level_cyclic_import(tmpdir, loader): 'foo.thrift', [ ('foo.thrift', ''' - include "bar.thrift" + include "./bar.thrift" const i32 x = bar.y; '''), @@ -372,6 +392,18 @@ def test_multi_level_cyclic_import(tmpdir, loader): [('foo.thrift', 'const i32 x = bar.y')], ['Unknown constant "bar.y" referenced'] ), + ( + # Include path that doesn't start with '.' + 'foo.thrift', + [ + ('foo.thrift', 'include "bar.thrift"'), + ('bar.thrift', 'const i32 x = 42'), + ], + [ + 'Paths in include statements are relative', + 'must be in the form "./foo.thrift"' + ] + ), ]) def test_bad_includes(tmpdir, loader, root, data, msgs): for path, contents in data: @@ -390,7 +422,7 @@ def test_include_disallowed_with_loads(loads): namespace py foo namespace js bar - include "foo.thrift" + include "./foo.thrift" ''') assert ( diff --git a/thriftrw/compile/compiler.py b/thriftrw/compile/compiler.py index 51cf147..d137aa4 100644 --- a/thriftrw/compile/compiler.py +++ b/thriftrw/compile/compiler.py @@ -272,6 +272,13 @@ def visit_include(self, include): % (include.path, include.lineno) ) + if not any(include.path.startswith(p) for p in ('./', '../')): + raise ThriftCompilerError( + 'Paths in include statements are relative to the directory ' + 'containing the Thrift file. They must be in the form ' + '"./foo.thrift" or "../bar.thrift".' + ) + # Includes are relative to directory of the Thrift file being # compiled. path = os.path.join( diff --git a/thriftrw/idl/ast.py b/thriftrw/idl/ast.py index 36baed3..9873491 100644 --- a/thriftrw/idl/ast.py +++ b/thriftrw/idl/ast.py @@ -87,7 +87,7 @@ class Include(namedtuple('Include', 'path lineno')): :: - include "common.thrift" + include "./common.thrift" .. py:attribute:: path From 432a28ccb510610ab3d15cdfbe7b7d196bb496b2 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Tue, 3 Nov 2015 09:27:03 -0800 Subject: [PATCH 4/4] Makefile: rename bootstrap to install --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 66c6dfd..42713fe 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test lint docs docsopen clean bootstrap +.PHONY: test lint docs docsopen clean install test_args := \ --cov thriftrw \ @@ -30,7 +30,7 @@ clean: find tests thriftrw -name \*.so -delete make -C docs clean -bootstrap: +install: pip install -r requirements.txt pip install -r requirements-dev.txt pip install -r requirements-test.txt