diff --git a/.gitignore b/.gitignore index c3d48cb..4c6d916 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ env/ _build *.swp node_modules +*.c +*.so +.benchmarks/ +.coverage diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5a08659..c9bdab4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Releases 0.6.0 (unreleased) ------------------ +- ``include`` statements are now supported. For more information, see + :ref:`including-modules`. - Added support for message envelopes. This makes it possible to talk with standard Apache Thrift services and clients. For more information, see :ref:`calling-apache-thrift`. @@ -19,6 +21,9 @@ Releases ``FunctionSpec``. - ``ServiceSpec`` now provides a ``lookup`` method to look up ``FunctionSpecs`` by name. +- Removed the ``force`` option on ``Loader.load``. +- In generated modules, renamed the ``types``, ``constants`` and ``services`` + attributes to ``__types__``, ``__constants__``, and ``__services__``. 0.5.2 (2015-10-19) diff --git a/Makefile b/Makefile index 48d0364..42713fe 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test lint docs docsopen clean +.PHONY: test lint docs docsopen clean install test_args := \ --cov thriftrw \ @@ -29,3 +29,9 @@ clean: find tests thriftrw -name \*.c -delete find tests thriftrw -name \*.so -delete make -C docs clean + +install: + pip install -r requirements.txt + pip install -r requirements-dev.txt + pip install -r requirements-test.txt + pip install -e . diff --git a/docs/overview.rst b/docs/overview.rst index 4ca6b4c..f4dd996 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -155,21 +155,54 @@ The generated module contains two top-level functions ``dumps`` and ``loads``. If the method name was not recognized or any other payload parsing errors. -.. py:attribute:: services +.. py:attribute:: __services__ Collection of all classes generated for all services defined in the source thrift file. -.. py:attribute:: types + .. 
versionchanged:: 0.6 + + Renamed from ``services`` to ``__services__``. + +.. py:attribute:: __types__ Collection of all classes generated for all types defined in the source thrift file. -.. py:attribute:: constants + .. versionchanged:: 0.6 + + Renamed from ``types`` to ``__types__``. + +.. py:attribute:: __includes__ + + Collection of modules which were referenced via ``include`` statements in + the generated module. + + .. versionadded:: 0.6 + +.. py:attribute:: __constants__ Mapping of constant name to value for all constants defined in the source thrift file. + .. versionchanged:: 0.6 + + Renamed from ``constants`` to ``__constants__``. + +Includes +~~~~~~~~ + +For an include:: + + include "./shared.thrift" + +The generated module will include a top-level attribute ``shared`` which +references the generated module for ``shared.thrift``. + +Note that paths in include statements are relative to the directory containing +the Thrift file and they must be in the form ``./foo.thrift``, +``./foo/bar.thrift``, ``../baz.thrift``, and so on. + Structs ~~~~~~~ @@ -413,6 +446,60 @@ Thrift Type Primitive Type ``exception`` ``dict`` ============= ============== +.. _including-modules: + +Including other Thrift files +---------------------------- + +Types, services, and constants defined in different Thrift files may be +referenced by using ``include`` statements with paths **relative to the current +.thrift file**. The paths must be in the form ``./foo.thrift``, +``./foo/bar.thrift``, ``../baz.thrift``, and so on. + +Included modules will automatically be compiled along with the +module that included them, and they will be made available in the generated +module with the base name of the included file.
+ +For example, given:: + + // shared/types.thrift + + struct UUID { + 1: required i64 high + 2: required i64 low + } + +And:: + + // services/user.thrift + + include "../shared/types.thrift" + + struct User { + 1: required types.UUID uuid + } + +You can do the following: + +.. code-block:: python + + service = thriftrw.load('services/user.thrift') + + user_uuid = service.types.UUID(...) + user = service.User(uuid=user_uuid) + + # ... + +Also note that you can ``load()`` Thrift files that have already been loaded +without extra cost because the result is cached by the system. + +.. code-block:: python + + service = thriftrw.load('services/user.thrift') + types = thriftrw.load('shared/types.thrift') + + assert service.types is types + .. _calling-apache-thrift: Calling Apache Thrift diff --git a/tests/compile/test_compiler.py b/tests/compile/test_compiler.py index 1f791bc..93c2622 100644 --- a/tests/compile/test_compiler.py +++ b/tests/compile/test_compiler.py @@ -21,41 +21,11 @@ from __future__ import absolute_import, unicode_literals, print_function import pytest -from functools import partial -from thriftrw.idl import Parser -from thriftrw.protocol import BinaryProtocol -from thriftrw.compile import Compiler from thriftrw.errors import ThriftCompilerError @pytest.fixture -def parse(): - return Parser().parse - - -@pytest.fixture -def compile(request): - return partial(Compiler(BinaryProtocol()).compile, request.node.name) - - -@pytest.fixture -def loads(parse, compile): - return (lambda s: compile(parse(s))) - - -def test_include_disallowed(loads): - with pytest.raises(ThriftCompilerError) as exc_info: - loads(''' - namespace py foo - namespace js bar - - include "foo.thrift" - ''') - - assert 'thriftrw does not support including' in str(exc_info) - - def test_unknown_type(loads): with pytest.raises(ThriftCompilerError) as exc_info: loads(''' @@ -118,14 +88,14 @@ def test_services_and_types(loads): 'z': [m.x, m.y], 'x': 42, 'y': 123, - } == m.constants + } ==
m.__constants__ assert ( - m.types == (m.Foo, m.Bar) or - m.types == (m.Bar, m.Foo) + m.__types__ == (m.Foo, m.Bar) or + m.__types__ == (m.Bar, m.Foo) ) assert ( - m.services == (m.A, m.B) or - m.services == (m.B, m.A) + m.__services__ == (m.A, m.B) or + m.__services__ == (m.B, m.A) ) diff --git a/tests/compile/test_includes.py b/tests/compile/test_includes.py new file mode 100644 index 0000000..b7bc841 --- /dev/null +++ b/tests/compile/test_includes.py @@ -0,0 +1,430 @@ +# Copyright (c) 2015 Uber Technologies, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +from __future__ import absolute_import, unicode_literals, print_function + +import time +import pytest + +from thriftrw.loader import Loader +from thriftrw.protocol import BinaryProtocol +from thriftrw.errors import ThriftCompilerError + + +@pytest.fixture +def loader(): + return Loader(BinaryProtocol()) + + +def test_simple_include(tmpdir, loader): + tmpdir.join('types.thrift').write(''' + struct Item { + 1: required string key + 2: required string value + } + ''') + + tmpdir.join('svc.thrift').write(''' + include "./types.thrift" + + struct BatchGetResponse { + 1: required list items = [] + } + + service ItemStore { + BatchGetResponse batchGetItems(1: list keys) + } + ''') + + svc = loader.load(str(tmpdir.join('svc.thrift'))) + + # Loading the module we depend on explicitly should give back the same + # generated module. + assert svc.types is loader.load(str(tmpdir.join('types.thrift'))) + assert svc.__includes__ == (svc.types,) + + item = svc.types.Item(key='foo', value='bar') + response = svc.BatchGetResponse([item]) + + assert svc.types.dumps(item) == bytes(bytearray([ + # 1: 'foo' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x66, 0x6f, 0x6f, + + # 2: 'bar' + 0x0B, + 0x00, 0x02, + 0x00, 0x00, 0x00, 0x03, + 0x62, 0x61, 0x72, + + 0x00, + ])) + + svc.dumps(response) == bytes(bytearray([ + # 1: [item] + 0x0F, + 0x00, 0x01, + + # item + 0x0C, + 0x00, 0x00, 0x00, 0x01, + + # 1: 'foo' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x66, 0x6f, 0x6f, + + # 2: 'bar' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, + 0x62, 0x61, 0x72, + + 0x00, + ])) + + +def test_include_relative(tmpdir, loader): + tmpdir.join('types/shared.thrift').ensure().write(''' + typedef i64 Timestamp + + exception InternalError { + 1: required string message + } + ''') + + tmpdir.join('team/myservice/myservice.thrift').ensure().write(''' + include "../../types/shared.thrift" + + service Service { + shared.Timestamp getCurrentTime() + throws (1: shared.InternalError internalError) + } 
+ ''') + + myservice = loader.load( + str(tmpdir.join('team/myservice/myservice.thrift')) + ) + + assert myservice.__includes__ == (myservice.shared,) + + myservice.Service.getCurrentTime.response( + success=int(time.time() * 1000) + ) + + myservice.Service.getCurrentTime.response( + internalError=myservice.shared.InternalError('great sadness') + ) + + with pytest.raises(TypeError) as exc_info: + myservice.Service.getCurrentTime.response( + success='2015-10-29T15:00:00Z' + ) + assert 'Cannot serialize' in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + myservice.Service.getCurrentTime.response( + internalError=ZeroDivisionError() + ) + + assert 'Cannot serialize' in str(exc_info) + + +def test_cyclic_includes(tmpdir, loader): + tmpdir.join('node.thrift').write(''' + include "./value.thrift" + + struct Node { + 1: required string name + 2: required value.Value value + } + ''') + + tmpdir.join('value.thrift').write(''' + include "./node.thrift" + + struct Value { + 1: required list nodes + } + ''') + + node = loader.load(str(tmpdir.join('node.thrift'))) + + assert node.__includes__ == (node.value,) + assert node.value.__includes__ == (node,) + + assert ( + node.dumps(node.Node('hello', node.value.Value([]))) == + bytes(bytearray([ + + # 1: 'hello' + 0x0B, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x05, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, # 'hello' + + # 2: {1: []} + 0x0C, + 0x00, 0x02, + + # 1: [] + 0x0F, + 0x00, 0x01, + + # [] + 0x0C, + 0x00, 0x00, 0x00, 0x00, + + 0x00, + + 0x00, + ])) + ) + + +def test_inherit_included_service(tmpdir, loader): + tmpdir.join('common.thrift').write(''' + service BaseService { + string serviceName() + bool healthy() + } + ''') + + tmpdir.join('keyvalue.thrift').write(''' + include "./common.thrift" + + service KeyValue extends common.BaseService { + binary get(1: binary key) + void put(1: binary key, 2: binary value) + } + ''') + + keyvalue = loader.load(str(tmpdir.join('keyvalue.thrift'))) + + assert keyvalue.__includes__ == 
(keyvalue.common,) + + assert issubclass(keyvalue.KeyValue, keyvalue.common.BaseService) + + assert ( + keyvalue.dumps(keyvalue.KeyValue.healthy.response(success=True)) == + bytes(bytearray([0x02, 0x00, 0x00, 0x01, 0x00])) + ) + + +def test_include_constants(tmpdir, loader): + tmpdir.join('bar.thrift').write('const i32 some_num = 42') + + tmpdir.join('foo.thrift').write(''' + include "./bar.thrift" + + const list nums = [1, bar.some_num, 2]; + ''') + + foo = loader.load(str(tmpdir.join('foo.thrift'))) + assert foo.nums == [1, 42, 2] == [1, foo.bar.some_num, 2] + + +def test_include_enums(tmpdir, loader): + tmpdir.join('foo.thrift').write(''' + enum Role { + DISABLED = 0, + USER = 1, + MOD = 2, + ADMIN = 3, + } + ''') + + tmpdir.join('bar.thrift').write(''' + include "./foo.thrift" + + const foo.Role DEFAULT_ROLE = foo.Role.USER + ''') + + bar = loader.load(str(tmpdir.join('bar.thrift'))) + assert bar.DEFAULT_ROLE == bar.foo.Role.USER == 1 + + +def test_multi_level_cyclic_import(tmpdir, loader): + + # |- a.thrift + # |- one/ + # |- b.thrift + # |- c.thrift + # |- two/ + # |- d.thrift + + tmpdir.join('a.thrift').write(''' + include "./one/b.thrift" + include "./one/c.thrift" + ''') + + tmpdir.join('one/b.thrift').ensure().write(''' + include "./two/d.thrift" + ''') + + tmpdir.join('one/c.thrift').ensure().write(''' + include "./two/d.thrift" + ''') + + tmpdir.join('one/two/d.thrift').ensure().write(''' + include "../../a.thrift" + ''') + + a = loader.load(str(tmpdir.join('a.thrift'))) + assert ( + a.__includes__ == (a.b, a.c) or + a.__includes__ == (a.c, a.b) + ) + + assert a.b.d is a.c.d + + assert a.b.__includes__ == (a.b.d,) + assert a.c.__includes__ == (a.c.d,) + + assert a.b.d.a is a + assert a.c.d.a is a + assert a.b.d.__includes__ == (a,) + assert a.c.d.__includes__ == (a,) + + +@pytest.mark.parametrize('root, data, msgs', [ + ( + # File does not exist + 'foo.thrift', + [('foo.thrift', 'include "./bar.thrift"')], + [ + 'Cannot include "./bar.thrift"', + 'The 
file', 'does not exist' + ] + ), + ( + # Two modules in subdirectories with the same name. This should be + # resolved once we implement the "import as" syntax. + 'index.thrift', + [ + ('foo/shared.thrift', 'typedef string timestamp'), + ('bar/shared.thrift', 'typedef string UUID'), + ('index.thrift', ''' + include "./foo/shared.thrift" + include "./bar/shared.thrift" + '''), + ], + [ + 'Cannot include module "shared"', + 'The name is already taken' + ] + ), + ( + # Unknown type reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "./bar.thrift" + + struct Foo { 1: required bar.Bar b } + '''), + ('bar.thrift', ''), + ], + ['Unknown type "Bar" referenced'] + ), + ( + # Unknown service reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "./bar.thrift" + + service Foo extends bar.Bar { + } + '''), + ('bar.thrift', 'service NotBar {}'), + ], + ['Unknown service "Bar" referenced'] + ), + ( + # Unknown constant reference + 'foo.thrift', + [ + ('foo.thrift', ''' + include "./bar.thrift" + + const i32 x = bar.y; + '''), + ('bar.thrift', 'const i32 z = 42'), + ], + ['Unknown constant "y" referenced'] + ), + ( + # Bad type reference + 'foo.thrift', + [('foo.thrift', 'struct Foo { 1: required bar.Bar b }')], + ['Unknown type "bar.Bar" referenced'] + ), + ( + # Bad service reference + 'foo.thrift', + [('foo.thrift', 'service Foo extends bar.Bar {}')], + ['Unknown service "bar.Bar" referenced'] + ), + ( + # Bad constant reference + 'foo.thrift', + [('foo.thrift', 'const i32 x = bar.y')], + ['Unknown constant "bar.y" referenced'] + ), + ( + # Include path that doesn't start with '.' 
+ 'foo.thrift', + [ + ('foo.thrift', 'include "bar.thrift"'), + ('bar.thrift', 'const i32 x = 42'), + ], + [ + 'Paths in include statements are relative', + 'must be in the form "./foo.thrift"' + ] + ), +]) +def test_bad_includes(tmpdir, loader, root, data, msgs): + for path, contents in data: + tmpdir.join(path).ensure().write(contents) + + with pytest.raises(ThriftCompilerError) as exc_info: + loader.load(str(tmpdir.join(root))) + + for msg in msgs: + assert msg in str(exc_info) + + +def test_include_disallowed_with_loads(loads): + with pytest.raises(ThriftCompilerError) as exc_info: + loads(''' + namespace py foo + namespace js bar + + include "./foo.thrift" + ''') + + assert ( + 'Includes are not supported when using the "loads()"' in str(exc_info) + ) diff --git a/tests/spec/test_service.py b/tests/spec/test_service.py index fac52c7..596f9a9 100644 --- a/tests/spec/test_service.py +++ b/tests/spec/test_service.py @@ -29,6 +29,7 @@ from thriftrw.spec.service import ServiceSpec from thriftrw.spec.struct import FieldSpec from thriftrw.idl import Parser +from thriftrw.idl.ast import ServiceReference from thriftrw.wire import ttype from ..util.value import vstruct, vbinary, vmap, vbool @@ -117,7 +118,7 @@ def test_compile(parse): ''')) assert spec.name == 'KeyValue' - assert spec.parent == 'BaseService' + assert spec.parent == ServiceReference('BaseService', 2) put_item_spec = spec.functions[0] get_item_spec = spec.functions[1] @@ -161,7 +162,7 @@ def test_link_unknown_parent(loads): with pytest.raises(ThriftCompilerError) as exc_info: loads('service A extends B {}') - assert 'Service "A" inherits from unknown service "B"' in str(exc_info) + assert 'Unknown service "B" referenced at line' in str(exc_info) def test_load(loads): @@ -200,8 +201,8 @@ def test_load(loads): } ''') assert ( - (keyvalue.KeyValue, keyvalue.BaseService) == keyvalue.services or - (keyvalue.BaseService, keyvalue.KeyValue) == keyvalue.services + (keyvalue.KeyValue, keyvalue.BaseService) == 
keyvalue.__services__ or + (keyvalue.BaseService, keyvalue.KeyValue) == keyvalue.__services__ ) KeyValue = keyvalue.KeyValue diff --git a/tests/test_loader.py b/tests/test_loader.py index 173cdd1..6e00d54 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -82,14 +82,9 @@ def test_caching(tmpdir, monkeypatch): loader = Loader() mod1 = loader.load(path) - assert path in loader.compiled_modules - mod2 = loader.load(path) assert mod1 is mod2 - mod3 = loader.load(path, force=True) - assert mod3 is not mod2 - @pytest.mark.unimport('foo.bar.svc') def test_install_absolute(tmpdir, monkeypatch): diff --git a/thriftrw/_buffer.pyx b/thriftrw/_buffer.pyx index 7479298..8b47c3d 100644 --- a/thriftrw/_buffer.pyx +++ b/thriftrw/_buffer.pyx @@ -152,9 +152,9 @@ cdef class ReadBuffer(object): Number of bytes to read. :returns: ``bytes`` object containing exactly ``count`` bytes. - :raises Exception: + :raises EndOfInputError: If the number of bytes available in the buffer is less than - ``count``. TODO more specific exception + ``count``. 
""" if count > self.length - self.offset: raise EndOfInputError( diff --git a/thriftrw/compile/__init__.py b/thriftrw/compile/__init__.py index aafeb1c..9b7a3e9 100644 --- a/thriftrw/compile/__init__.py +++ b/thriftrw/compile/__init__.py @@ -24,6 +24,7 @@ from __future__ import absolute_import, unicode_literals, print_function from .compiler import Compiler +from .scope import Scope -__all__ = ['Compiler'] +__all__ = ['Compiler', 'Scope'] diff --git a/thriftrw/compile/compiler.py b/thriftrw/compile/compiler.py index 68de90d..d137aa4 100644 --- a/thriftrw/compile/compiler.py +++ b/thriftrw/compile/compiler.py @@ -20,25 +20,113 @@ from __future__ import absolute_import, unicode_literals, print_function +import os.path + +from .scope import Scope from .generate import Generator from .link import TypeSpecLinker from .link import ConstSpecLinker from .link import ServiceSpecLinker -from .scope import Scope from ..errors import ThriftCompilerError +from thriftrw.idl import Parser from thriftrw._runtime import Serializer, Deserializer __all__ = ['Compiler'] +LINKERS = [ConstSpecLinker, TypeSpecLinker, ServiceSpecLinker] + + +class ModuleSpec(object): + """Specification for a single module.""" + + __slots__ = ( + 'name', 'path', 'scope', 'surface', 'includes', 'protocol', 'linked' + ) + + def __init__(self, name, protocol, path=None): + """ + :param name: + Name of the module. + :param path: + Path to the Thrift file from which this module was compiled. This + may be omitted if the module was compiled from an inline string + (using the ``loads()`` API, for example). + """ + + self.name = name + self.path = path + self.protocol = protocol + self.linked = False + self.scope = Scope(name, path) + self.surface = self.scope.module + + # Mapping of names of inculded modules to their corresponding specs. + self.includes = {} + + # TODO Scope can probably be eventually folded into this class. 
+ + @property + def can_include(self): + """Whether this module is allowed to include other modules. + + This is allowed only if the module was compiled from a file since + include paths are relative to the file in which they are mentioned. + """ + return self.path is not None + + def add_include(self, module_spec): + """Adds a module as an included module. + + :param module_spec: + ModuleSpec of the included module. + """ + assert self.can_include + + if module_spec.name in self.includes: + raise ThriftCompilerError( + 'Cannot include module "%s" in "%s". ' + 'The name is already taken.' + % (module_spec.name, self.path) + ) + + self.includes[module_spec.name] = module_spec + self.scope.add_include( + module_spec.name, + module_spec.scope, + module_spec.surface, + ) + + def link(self): + """Link all the types in this module and all included modules.""" + if self.linked: + return self + + self.linked = True + + included_modules = [] + + # Link includes + for include in self.includes.values(): + included_modules.append(include.link().surface) + + self.scope.add_surface('__includes__', tuple(included_modules)) + + # Link self + for linker in LINKERS: + linker(self.scope).link() + + self.scope.add_surface('loads', Deserializer(self.protocol)) + self.scope.add_surface('dumps', Serializer(self.protocol)) + + return self + class Compiler(object): """Compiles IDLs into Python modules.""" - LINKERS = [ConstSpecLinker, TypeSpecLinker, ServiceSpecLinker] - - __slots__ = ('protocol', 'strict') + __slots__ = ('protocol', 'strict', 'parser', '_module_specs') def __init__(self, protocol, strict=True): """Initialize the compiler. @@ -49,26 +137,49 @@ def __init__(self, protocol, strict=True): self.protocol = protocol self.strict = strict - def compile(self, name, program): - """Compile the given parsed Thrift document into a Python module. + self.parser = Parser() + + # Mapping from absolute file path to ModuleSpec for all modules. 
+ self._module_specs = {} + + def compile(self, name, contents, path=None): + """Compile the given Thrift document into a Python module. The generated module contains, - .. py:attribute:: services + .. py:attribute:: __services__ A collection of generated classes for all services defined in the thrift file. - .. py:attribute:: types + .. versionchanged:: 0.6 + + Renamed from ``services`` to ``__services__``. + + .. py:attribute:: __types__ A collection of generated types for all types defined in the thrift file. - .. py:attribute:: constants + .. versionchanged:: 0.6 + + Renamed from ``types`` to ``__types__``. + + .. py:attribute:: __includes__ + + A collection of modules included by this module. + + .. versionadded:: 0.6 + + .. py:attribute:: __constants__ A mapping of constant name to value for all constants defined in the thrift file. + .. versionchanged:: 0.6 + + Renamed from ``constants`` to ``__constants__``. + .. py:function:: dumps(obj) Serializes the given object using the protocol the compiler was @@ -107,41 +218,86 @@ def compile(self, name, program): :py:class:`thriftrw.spec.ServiceFunction` objects for each method defined in the service. - .. versionadded:: 0.2 - The ``constants`` attribute in generated modules. - :param str name: Name of the Thrift document. This will be the name of the generated module. - :param thriftrw.idl.Program program: - AST of the parsted Thrift document. + :param str contents: + Thrift document to compile + :param str path: + Path to the Thrift file being compiled. If not specified, imports + from within the Thrift file will be disallowed. :returns: - The generated module. + ModuleSpec of the generated module. 
""" - scope = Scope(name) + assert name + + if path: + path = os.path.abspath(path) + if path in self._module_specs: + return self._module_specs[path] + + module_spec = ModuleSpec(name, self.protocol, path) + if path: + self._module_specs[path] = module_spec + + program = self.parser.parse(contents) + header_processor = HeaderProcessor(self, module_spec) for header in program.headers: - header.apply(self) + header.apply(header_processor) - generator = Generator(scope, strict=self.strict) + generator = Generator(module_spec.scope, strict=self.strict) for definition in program.definitions: generator.process(definition) - # TODO Linker can probably just be a callable. - for linker in self.LINKERS: - linker(scope).link() + return module_spec - scope.add_surface('loads', Deserializer(self.protocol)) - scope.add_surface('dumps', Serializer(self.protocol)) - return scope.module +class HeaderProcessor(object): + """Processes headers found in the Thrift file.""" + + __slots__ = ('compiler', 'module_spec') + + def __init__(self, compiler, module_spec): + self.compiler = compiler + self.module_spec = module_spec def visit_include(self, include): - raise ThriftCompilerError( - 'Include of "%s" found on line %d. ' - 'thriftrw does not support including other Thrift files.' - % (include.path, include.lineno) + + if not self.module_spec.can_include: + raise ThriftCompilerError( + 'Include of "%s" found on line %d. ' + 'Includes are not supported when using the "loads()" API.' + 'Try loading the file using the "load()" API.' + % (include.path, include.lineno) + ) + + if not any(include.path.startswith(p) for p in ('./', '../')): + raise ThriftCompilerError( + 'Paths in include statements are relative to the directory ' + 'containing the Thrift file. They must be in the form ' + '"./foo.thrift" or "../bar.thrift".' + ) + + # Includes are relative to directory of the Thrift file being + # compiled. 
+ path = os.path.join( + os.path.dirname(self.module_spec.path), include.path ) + if not os.path.isfile(path): + raise ThriftCompilerError( + 'Cannot include "%s" on line %d in %s. ' + 'The file "%s" does not exist.' + % (include.path, include.lineno, self.module_spec.path, path) + ) + + name = os.path.splitext(os.path.basename(include.path))[0] + with open(path, 'r') as f: + contents = f.read() + + included_module_spec = self.compiler.compile(name, contents, path) + self.module_spec.add_include(included_module_spec) + def visit_namespace(self, namespace): pass # nothing to do diff --git a/thriftrw/compile/link.py b/thriftrw/compile/link.py index 0e1aa4b..bdfaf12 100644 --- a/thriftrw/compile/link.py +++ b/thriftrw/compile/link.py @@ -47,7 +47,7 @@ def link(self): types.append(type_spec.surface) self.scope.type_specs = type_specs - self.scope.add_surface('types', tuple(types)) + self.scope.add_surface('__types__', tuple(types)) class ServiceSpecLinker(object): @@ -73,7 +73,7 @@ def link(self): services.append(service_spec.surface) self.scope.service_specs = service_specs - self.scope.add_surface('services', tuple(services)) + self.scope.add_surface('__services__', tuple(services)) class ConstSpecLinker(object): @@ -99,4 +99,4 @@ def link(self): constants[const_spec.name] = const_spec.surface self.scope.const_specs = const_specs - self.scope.add_surface('constants', constants) + self.scope.add_surface('__constants__', constants) diff --git a/thriftrw/compile/scope.py b/thriftrw/compile/scope.py index cea4c90..7d0b164 100644 --- a/thriftrw/compile/scope.py +++ b/thriftrw/compile/scope.py @@ -35,17 +35,22 @@ class Scope(object): reference to the final generated module. """ - __slots__ = ('const_specs', 'type_specs', 'module', 'service_specs') + __slots__ = ( + 'const_specs', 'type_specs', 'module', 'service_specs', + 'included_scopes', 'path' + ) - def __init__(self, name): + def __init__(self, name, path=None): """Initialize the scope. 
:param name: Name of the generated module. """ + self.path = path self.type_specs = {} self.const_specs = {} self.service_specs = {} + self.included_scopes = {} self.module = types.ModuleType(str(name)) @@ -59,21 +64,83 @@ def __str__(self): __repr__ = __str__ + def __in_path(self): + """Helper for error messages to say "in $path" if the scope has a + non-none path. + """ + if self.path: + return ' in "%s"' % self.path + else: + return '' + def resolve_const_spec(self, name, lineno): """Finds and links the ConstSpec with the given name.""" - if name not in self.const_specs: - raise ThriftCompilerError( - 'Unknown constant "%s" referenced at line %d' % (name, lineno) + + if name in self.const_specs: + return self.const_specs[name].link(self) + + if '.' in name: + include_name, component = name.split('.', 1) + if include_name in self.included_scopes: + return self.included_scopes[include_name].resolve_const_spec( + component, lineno + ) + + raise ThriftCompilerError( + 'Unknown constant "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() ) - return self.const_specs[name].link(self) + ) def resolve_type_spec(self, name, lineno): """Finds and links the TypeSpec with the given name.""" - if name not in self.type_specs: - raise ThriftCompilerError( - 'Unknown type "%s" referenced at line %d' % (name, lineno) + + if name in self.type_specs: + return self.type_specs[name].link(self) + + if '.' in name: + include_name, component = name.split('.', 1) + if include_name in self.included_scopes: + return self.included_scopes[include_name].resolve_type_spec( + component, lineno + ) + + raise ThriftCompilerError( + 'Unknown type "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() ) - return self.type_specs[name].link(self) + ) + + def resolve_service_spec(self, name, lineno): + """Finds and links the ServiceSpec with the given name.""" + + if name in self.service_specs: + return self.service_specs[name].link(self) + + if '.' 
in name: + include_name, component = name.split('.', 2) + if include_name in self.included_scopes: + return self.included_scopes[ + include_name + ].resolve_service_spec(component, lineno) + + raise ThriftCompilerError( + 'Unknown service "%s" referenced at line %d%s' % ( + name, lineno, self.__in_path() + ) + ) + + def add_include(self, name, included_scope, module): + """Register an imported module into this scope. + + Raises ``ThriftCompilerError`` if the name has already been used. + """ + # The compiler already ensures this. If we still get here with a + # conflict, that's a bug. + assert name not in self.included_scopes + + self.included_scopes[name] = included_scope + self.add_surface(name, module) def add_service_spec(self, service_spec): """Registers the given ``ServiceSpec`` into the scope. diff --git a/thriftrw/idl/__init__.py b/thriftrw/idl/__init__.py index 5d87371..70fe7a9 100644 --- a/thriftrw/idl/__init__.py +++ b/thriftrw/idl/__init__.py @@ -60,6 +60,8 @@ .. autoclass:: Service +.. autoclass:: ServiceReference + .. autoclass:: Function .. autoclass:: Field @@ -104,6 +106,7 @@ Union, Exc, Service, + ServiceReference, Function, Field, PrimitiveType, @@ -134,6 +137,7 @@ 'Union', 'Exc', 'Service', + 'ServiceReference', 'Function', 'Field', 'PrimitiveType', diff --git a/thriftrw/idl/ast.py b/thriftrw/idl/ast.py index 81b166c..9873491 100644 --- a/thriftrw/idl/ast.py +++ b/thriftrw/idl/ast.py @@ -34,6 +34,7 @@ 'Union', 'Exc', 'Service', + 'ServiceReference', 'Function', 'Field', 'PrimitiveType', @@ -86,7 +87,7 @@ class Include(namedtuple('Include', 'path lineno')): :: - include "common.thrift" + include "./common.thrift" .. py:attribute:: path @@ -277,6 +278,15 @@ def apply(self, visitor): # defined by Python. +class ServiceReference(namedtuple('ServiceReference', 'name lineno')): + """A reference to another service. + + .. py:attribute:: name + + Name of the referenced service. 
+ """ + + class Service( namedtuple('Service', 'name functions parent annotations lineno') ): @@ -293,8 +303,8 @@ class Service( .. py:attribute:: parent - Name of the service that this service extends. ``None`` if this service - doesn't have a parent service. + :py:class:`ServiceReference` to the parent service or ``None`` if this + service doesn't have a parent service. .. py:attribute:: annotations diff --git a/thriftrw/idl/parser.py b/thriftrw/idl/parser.py index 4ca347d..7965d58 100644 --- a/thriftrw/idl/parser.py +++ b/thriftrw/idl/parser.py @@ -235,7 +235,7 @@ def p_service(self, p): p[0] = ast.Service( name=p[2], functions=p[6], - parent=p[4], + parent=ast.ServiceReference(p[4], p.lineno(4)), annotations=p[8], lineno=p.lineno(2), ) diff --git a/thriftrw/loader.py b/thriftrw/loader.py index b714b2f..2396350 100644 --- a/thriftrw/loader.py +++ b/thriftrw/loader.py @@ -23,7 +23,6 @@ import inspect import os.path -from .idl import Parser from .compile import Compiler from .protocol import BinaryProtocol @@ -31,7 +30,7 @@ class Loader(object): """Loads and compiles Thrift files.""" - __slots__ = ('parser', 'compiler', 'compiled_modules') + __slots__ = ('compiler',) def __init__(self, protocol=None, strict=True): """Initialize a loader. @@ -49,14 +48,8 @@ def __init__(self, protocol=None, strict=True): compatibility with existing Thrift files. """ protocol = protocol or BinaryProtocol() - - self.parser = Parser() self.compiler = Compiler(protocol, strict=strict) - # Mapping of absolute file path to compiled module. This is used to - # cache the result of calling load() multiple times on the same file. - self.compiled_modules = {} - def loads(self, name, document): """Parse and compile the given Thrift document. @@ -65,38 +58,27 @@ def loads(self, name, document): :param str document: The Thrift IDL as a string.
""" - program = self.parser.parse(document) - return self.compiler.compile(name, program) + return self.compiler.compile(name, document).link().surface - def load(self, path, name=None, force=False): + def load(self, path, name=None): """Load and compile the given Thrift file. - If the file was already compiled before, a cached copy of the compiled - module is returned. - :param str path: Path to the ``.thrift`` file. :param str name: Name of the generated module. Defaults to the base name of the file. - :param bool force: - Whether to ignore the cache and load the file anew. Defaults to - False. :returns: The compiled module. """ - path = os.path.abspath(path) - if path in self.compiled_modules and not force: - return self.compiled_modules[path] - if name is None: - # TODO do we care if the file extension is .thrift? name = os.path.splitext(os.path.basename(path))[0] + # TODO do we care if the file extension is .thrift? + with open(path, 'r') as f: document = f.read() - module = self.loads(name, document) - self.compiled_modules[path] = module - return module + + return self.compiler.compile(name, document, path).link().surface _DEFAULT_LOADER = Loader(protocol=BinaryProtocol()) diff --git a/thriftrw/spec/service.pyx b/thriftrw/spec/service.pyx index 5a2a4e8..d959a23 100644 --- a/thriftrw/spec/service.pyx +++ b/thriftrw/spec/service.pyx @@ -357,12 +357,9 @@ class ServiceSpec(object): self.linked = True if self.parent is not None: - if self.parent not in scope.service_specs: - raise ThriftCompilerError( - 'Service "%s" inherits from unknown service "%s"' - % (self.name, self.parent) - ) - self.parent = scope.service_specs[self.parent].link(scope) + self.parent = scope.resolve_service_spec( + self.parent.name, self.parent.lineno + ) self.functions = [func.link(scope) for func in self.functions] self.surface = service_cls(self, scope)